diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..3f7bf37f4fb7aade9e03d03de5816d30c0e12903
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,20 @@
+Copyright (c) 2022 Dinesh Reddy and others
+
+Permission is hereby granted, free of charge, to any person obtaining
+a copy of this software and associated documentation files (the
+"Software"), to deal in the Software without restriction, including
+without limitation the rights to use, copy, modify, merge, publish,
+distribute, sublicense, and/or sell copies of the Software, and to
+permit persons to whom the Software is furnished to do so, subject to
+the following conditions:
+
+The above copyright notice and this permission notice shall be
+included in all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
diff --git a/README.md b/README.md
index b76041e71ee0065351e30e85e53447b878f59965..006bc76eece809c527302d681447fda8e8757e10 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,13 @@
---
-title: WALT
-emoji: 👁
-colorFrom: purple
-colorTo: yellow
+title: WALT DEMO
+emoji: ⚡
+colorFrom: indigo
+colorTo: indigo
sdk: gradio
-sdk_version: 3.0.21
+sdk_version: 3.0.20
app_file: app.py
pinned: false
+license: mit
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..b423b9923ca306f94b85eca2564122699fbe2f3c
--- /dev/null
+++ b/app.py
@@ -0,0 +1,79 @@
+import numpy as np
+import torch
+import gradio as gr
+from infer import detections
+'''
+import os
+os.system("mkdir data")
+os.system("mkdir data/models")
+os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
+os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
+'''
+def walt_demo(input_img, confidence_threshold):
+ #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
+    if not torch.cuda.is_available():
+        device = 'cpu'
+    else:
+        device = 'cuda:0'
+ #detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
+ detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=confidence_threshold)
+
+ count = 0
+ #img = detect_people.run_on_image(input_img)
+ output_img = detect.run_on_image(input_img)
+ #try:
+ #except:
+ # print("detecting on image failed")
+
+ return output_img
+
+description = """
+Demo of WALT on the WALT dataset. After watching a camera and automatically learning from it for several days, this approach significantly improves detection and segmentation of occluded people and vehicles compared to human-supervised amodal approaches.
+"""
+title = "WALT:Watch And Learn 2D Amodal Representation using Time-lapse Imagery"
+article="""
+
+
+
+"""
+
+examples = [
+ ['demo/images/img_1.jpg',0.8],
+ ['demo/images/img_2.jpg',0.8],
+ ['demo/images/img_4.png',0.85],
+]
+
+'''
+import cv2
+filename='demo/images/img_1.jpg'
+img=cv2.imread(filename)
+img=walt_demo(img)
+cv2.imwrite(filename.replace('/images/','/results/'),img)
+cv2.imwrite('check.png',img)
+'''
+confidence_threshold = gr.Slider(minimum=0.3,
+ maximum=1.0,
+ step=0.01,
+ value=1.0,
+ label="Amodal Detection Confidence Threshold")
+inputs = [gr.Image(), confidence_threshold]
+demo = gr.Interface(walt_demo,
+ outputs="image",
+ inputs=inputs,
+ article=article,
+ title=title,
+ enable_queue=True,
+ examples=examples,
+ description=description)
+
+#demo.launch(server_name="0.0.0.0", server_port=7000)
+demo.launch()
+
+
diff --git a/configs/_base_/datasets/parking_instance.py b/configs/_base_/datasets/parking_instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc0635f09ccaa67ce63165bb018f0cf161fbba11
--- /dev/null
+++ b/configs/_base_/datasets/parking_instance.py
@@ -0,0 +1,48 @@
+dataset_type = 'ParkingDataset'
+data_root = 'data/parking/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d','gt_bboxes_3d_proj']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=1,
+ workers_per_gpu=1,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + 'GT_data/',
+ img_prefix=data_root + 'images/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root + 'GT_data/',
+ img_prefix=data_root + 'images/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root + 'GT_data/',
+ img_prefix=data_root + 'images/',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox'])#, 'segm'])
diff --git a/configs/_base_/datasets/parking_instance_coco.py b/configs/_base_/datasets/parking_instance_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..85cee08021cabaca8cd6408c4e3efb2d7efae231
--- /dev/null
+++ b/configs/_base_/datasets/parking_instance_coco.py
@@ -0,0 +1,49 @@
+dataset_type = 'ParkingCocoDataset'
+data_root = 'data/parking/'
+data_root_test = 'data/parking_highres/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=6,
+ workers_per_gpu=6,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + 'GT_data/',
+ img_prefix=data_root + 'images/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root_test + 'GT_data/',
+ img_prefix=data_root_test + 'images',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root_test + 'GT_data/',
+ img_prefix=data_root_test + 'images',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox', 'segm'])
diff --git a/configs/_base_/datasets/people_real_coco.py b/configs/_base_/datasets/people_real_coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac50827efef253312971551ab55f1f26d72c7a7
--- /dev/null
+++ b/configs/_base_/datasets/people_real_coco.py
@@ -0,0 +1,49 @@
+dataset_type = 'WaltDataset'
+data_root = 'data/cwalt_train/'
+data_root_test = 'data/cwalt_test/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/',
+ img_prefix=data_root + '/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root_test + '/',
+ img_prefix=data_root_test + '/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root_test + '/',
+ img_prefix=data_root_test + '/',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox', 'segm'])
diff --git a/configs/_base_/datasets/walt_people.py b/configs/_base_/datasets/walt_people.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ac50827efef253312971551ab55f1f26d72c7a7
--- /dev/null
+++ b/configs/_base_/datasets/walt_people.py
@@ -0,0 +1,49 @@
+dataset_type = 'WaltDataset'
+data_root = 'data/cwalt_train/'
+data_root_test = 'data/cwalt_test/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=8,
+ workers_per_gpu=8,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/',
+ img_prefix=data_root + '/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root_test + '/',
+ img_prefix=data_root_test + '/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root_test + '/',
+ img_prefix=data_root_test + '/',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox', 'segm'])
diff --git a/configs/_base_/datasets/walt_vehicle.py b/configs/_base_/datasets/walt_vehicle.py
new file mode 100644
index 0000000000000000000000000000000000000000..466fa524d0f43b8684a01abe57188501787db8a4
--- /dev/null
+++ b/configs/_base_/datasets/walt_vehicle.py
@@ -0,0 +1,49 @@
+dataset_type = 'WaltDataset'
+data_root = 'data/cwalt_train/'
+data_root_test = 'data/cwalt_test/'
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+test_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(
+ type='MultiScaleFlipAug',
+ img_scale=(1333, 800),
+ flip=False,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ])
+]
+data = dict(
+ samples_per_gpu=5,
+ workers_per_gpu=5,
+ train=dict(
+ type=dataset_type,
+ ann_file=data_root + '/',
+ img_prefix=data_root + '/',
+ pipeline=train_pipeline),
+ val=dict(
+ type=dataset_type,
+ ann_file=data_root_test + '/',
+ img_prefix=data_root_test + '/',
+ pipeline=test_pipeline),
+ test=dict(
+ type=dataset_type,
+ ann_file=data_root_test + '/',
+ img_prefix=data_root_test + '/',
+ pipeline=test_pipeline))
+evaluation = dict(metric=['bbox', 'segm'])
diff --git a/configs/_base_/default_runtime.py b/configs/_base_/default_runtime.py
new file mode 100644
index 0000000000000000000000000000000000000000..55097c5b242da66c9735c0b45cd84beefab487b1
--- /dev/null
+++ b/configs/_base_/default_runtime.py
@@ -0,0 +1,16 @@
+checkpoint_config = dict(interval=1)
+# yapf:disable
+log_config = dict(
+ interval=50,
+ hooks=[
+ dict(type='TextLoggerHook'),
+ # dict(type='TensorboardLoggerHook')
+ ])
+# yapf:enable
+custom_hooks = [dict(type='NumClassCheckHook')]
+
+dist_params = dict(backend='nccl')
+log_level = 'INFO'
+load_from = None
+resume_from = None
+workflow = [('train', 1)]
diff --git a/configs/_base_/models/mask_rcnn_swin_fpn.py b/configs/_base_/models/mask_rcnn_swin_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3d42197f4646cd9ecafac2095d3f8e079f0a729
--- /dev/null
+++ b/configs/_base_/models/mask_rcnn_swin_fpn.py
@@ -0,0 +1,127 @@
+# model settings
+model = dict(
+ type='MaskRCNN',
+ pretrained=None,
+ backbone=dict(
+ type='SwinTransformer',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ use_checkpoint=False),
+ neck=dict(
+ type='FPN',
+ in_channels=[96, 192, 384, 768],
+ out_channels=256,
+ num_outs=5),
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ roi_head=dict(
+ type='StandardRoIHead',
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='Shared2FCBBoxHead',
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ mask_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ mask_head=dict(
+ type='FCNMaskHead',
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=80,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_pre=2000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ rpn=dict(
+ nms_pre=1000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5)))
diff --git a/configs/_base_/models/occ_mask_rcnn_swin_fpn.py b/configs/_base_/models/occ_mask_rcnn_swin_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..80b7a8cf5f8e358f95723cb44ddb853c8f194f7e
--- /dev/null
+++ b/configs/_base_/models/occ_mask_rcnn_swin_fpn.py
@@ -0,0 +1,127 @@
+# model settings
+model = dict(
+ type='MaskRCNN',
+ pretrained=None,
+ backbone=dict(
+ type='SwinTransformer',
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ use_checkpoint=False),
+ neck=dict(
+ type='FPN',
+ in_channels=[96, 192, 384, 768],
+ out_channels=256,
+ num_outs=5),
+ rpn_head=dict(
+ type='RPNHead',
+ in_channels=256,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ roi_head=dict(
+ type='StandardRoIHead',
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='Shared2FCBBoxHead',
+ in_channels=256,
+ fc_out_channels=1024,
+ roi_feat_size=7,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+ mask_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ mask_head=dict(
+ type='FCNOccMaskHead',
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=80,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+ # model training and testing settings
+ train_cfg=dict(
+ rpn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.7,
+ neg_iou_thr=0.3,
+ min_pos_iou=0.3,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=256,
+ pos_fraction=0.5,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=False),
+ allowed_border=-1,
+ pos_weight=-1,
+ debug=False),
+ rpn_proposal=dict(
+ nms_pre=2000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ assigner=dict(
+ type='MaxIoUAssigner',
+ pos_iou_thr=0.5,
+ neg_iou_thr=0.5,
+ min_pos_iou=0.5,
+ match_low_quality=True,
+ ignore_iof_thr=-1),
+ sampler=dict(
+ type='RandomSampler',
+ num=512,
+ pos_fraction=0.25,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True),
+ mask_size=28,
+ pos_weight=-1,
+ debug=False)),
+ test_cfg=dict(
+ rpn=dict(
+ nms_pre=1000,
+ max_per_img=1000,
+ nms=dict(type='nms', iou_threshold=0.7),
+ min_bbox_size=0),
+ rcnn=dict(
+ score_thr=0.05,
+ nms=dict(type='nms', iou_threshold=0.5),
+ max_per_img=100,
+ mask_thr_binary=0.5)))
diff --git a/configs/_base_/schedules/schedule_1x.py b/configs/_base_/schedules/schedule_1x.py
new file mode 100644
index 0000000000000000000000000000000000000000..13b3783cbbe93b6c32bc415dc50f633dffa4aec7
--- /dev/null
+++ b/configs/_base_/schedules/schedule_1x.py
@@ -0,0 +1,11 @@
+# optimizer
+optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+optimizer_config = dict(grad_clip=None)
+# learning policy
+lr_config = dict(
+ policy='step',
+ warmup='linear',
+ warmup_iters=500,
+ warmup_ratio=0.001,
+ step=[8, 11])
+runner = dict(type='EpochBasedRunner', max_epochs=12)
diff --git a/configs/walt/walt_people.py b/configs/walt/walt_people.py
new file mode 100644
index 0000000000000000000000000000000000000000..2dc45cd270a2cdb64f33a3a47b32eadd15a98c57
--- /dev/null
+++ b/configs/walt/walt_people.py
@@ -0,0 +1,80 @@
+_base_ = [
+ '../_base_/models/occ_mask_rcnn_swin_fpn.py',
+ '../_base_/datasets/walt_people.py',
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+ backbone=dict(
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ ape=False,
+ drop_path_rate=0.1,
+ patch_norm=True,
+ use_checkpoint=False
+ ),
+ neck=dict(in_channels=[96, 192, 384, 768]))
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='AutoAugment',
+ policies=[
+ [
+ dict(type='Resize',
+ img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+ (736, 1333), (768, 1333), (800, 1333)],
+ multiscale_mode='value',
+ keep_ratio=True)
+ ],
+ [
+ dict(type='Resize',
+ img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+ multiscale_mode='value',
+ keep_ratio=True),
+ dict(type='RandomCrop',
+ crop_type='absolute_range',
+ crop_size=(384, 600),
+ allow_negative_crop=True),
+ dict(type='Resize',
+ img_scale=[(480, 1333), (512, 1333), (544, 1333),
+ (576, 1333), (608, 1333), (640, 1333),
+ (672, 1333), (704, 1333), (736, 1333),
+ (768, 1333), (800, 1333)],
+ multiscale_mode='value',
+ override=True,
+ keep_ratio=True)
+ ]
+ ]),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
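+# _delete_=True drops the SGD optimizer inherited from schedule_1x.py so that only
+# the AdamW settings below are used (standard mmcv Config merge behaviour).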
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+lr_config = dict(step=[8, 11])
+runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+
+# do not use mmdet version fp16
+fp16 = None
+optimizer_config = dict(
+ type="DistOptimizerHook",
+ update_interval=1,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ use_fp16=True,
+)
diff --git a/configs/walt/walt_vehicle.py b/configs/walt/walt_vehicle.py
new file mode 100644
index 0000000000000000000000000000000000000000..93c82d75f40543b1a900494e6b1921717dc7188e
--- /dev/null
+++ b/configs/walt/walt_vehicle.py
@@ -0,0 +1,80 @@
+_base_ = [
+ '../_base_/models/occ_mask_rcnn_swin_fpn.py',
+ '../_base_/datasets/walt_vehicle.py',
+ '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+]
+
+model = dict(
+ backbone=dict(
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ ape=False,
+ drop_path_rate=0.1,
+ patch_norm=True,
+ use_checkpoint=False
+ ),
+ neck=dict(in_channels=[96, 192, 384, 768]))
+
+img_norm_cfg = dict(
+ mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+# augmentation strategy originates from DETR / Sparse RCNN
+train_pipeline = [
+ dict(type='LoadImageFromFile'),
+ dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+ dict(type='RandomFlip', flip_ratio=0.5),
+ dict(type='AutoAugment',
+ policies=[
+ [
+ dict(type='Resize',
+ img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+ (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+ (736, 1333), (768, 1333), (800, 1333)],
+ multiscale_mode='value',
+ keep_ratio=True)
+ ],
+ [
+ dict(type='Resize',
+ img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+ multiscale_mode='value',
+ keep_ratio=True),
+ dict(type='RandomCrop',
+ crop_type='absolute_range',
+ crop_size=(384, 600),
+ allow_negative_crop=True),
+ dict(type='Resize',
+ img_scale=[(480, 1333), (512, 1333), (544, 1333),
+ (576, 1333), (608, 1333), (640, 1333),
+ (672, 1333), (704, 1333), (736, 1333),
+ (768, 1333), (800, 1333)],
+ multiscale_mode='value',
+ override=True,
+ keep_ratio=True)
+ ]
+ ]),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='DefaultFormatBundle'),
+ dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+]
+data = dict(train=dict(pipeline=train_pipeline))
+
+optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+ paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+ 'relative_position_bias_table': dict(decay_mult=0.),
+ 'norm': dict(decay_mult=0.)}))
+lr_config = dict(step=[8, 11])
+runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+
+# do not use mmdet version fp16
+fp16 = None
+optimizer_config = dict(
+ type="DistOptimizerHook",
+ update_interval=1,
+ grad_clip=None,
+ coalesce=True,
+ bucket_size_mb=-1,
+ use_fp16=True,
+)
diff --git a/cwalt/CWALT.py b/cwalt/CWALT.py
new file mode 100644
index 0000000000000000000000000000000000000000..894578c1c75766cf27999dbb1fe64a4c4dcf4efb
--- /dev/null
+++ b/cwalt/CWALT.py
@@ -0,0 +1,161 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Tue Oct 19 19:14:47 2021
+
+@author: dinesh
+"""
+import glob
+from .utils import bb_intersection_over_union_unoccluded
+import numpy as np
+from PIL import Image
+import datetime
+import cv2
+import os
+from tqdm import tqdm
+
+
+def get_image(time, folder):
+    image = None
+    for week_loop in range(5):
+ try:
+ image = np.array(Image.open(folder+'/week' +str(week_loop)+'/'+ str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg'))
+ break
+ except:
+ continue
+ if image is None:
+ print('file not found')
+ return image
+
+def get_mask(segm, image):
+ poly = np.array(segm).reshape((int(len(segm)/2), 2))
+ mask = image.copy()*0
+ cv2.fillConvexPoly(mask, poly, (255, 255, 255))
+ return mask
+
+def get_unoccluded(indices, tracks_all):
+ unoccluded_indexes = []
+ unoccluded_index_all =[]
+ while 1:
+ unoccluded_clusters = []
+ len_unocc = len(unoccluded_indexes)
+ for ind in indices:
+ if ind in unoccluded_indexes:
+ continue
+ occ = False
+ for ind_compare in indices:
+ if ind_compare in unoccluded_indexes:
+ continue
+ if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare]) > 0.01 and ind_compare != ind:
+ occ = True
+ if occ==False:
+ unoccluded_indexes.extend([ind])
+ unoccluded_clusters.extend([ind])
+ if len(unoccluded_indexes) == len_unocc and len_unocc != 0:
+ for ind in indices:
+ if ind not in unoccluded_indexes:
+ unoccluded_indexes.extend([ind])
+ unoccluded_clusters.extend([ind])
+
+ unoccluded_index_all.append(unoccluded_clusters)
+ if len(unoccluded_indexes) > len(indices)-5:
+ break
+ return unoccluded_index_all
+
+def primes(n): # simple sieve of multiples
+ odds = range(3, n+1, 2)
+ sieve = set(sum([list(range(q*q, n+1, q+q)) for q in odds], []))
+ return [2] + [p for p in odds if p not in sieve]
+
+def save_image(image_read, save_path, data, path):
+ tracks = data['tracks_all_unoccluded']
+ segmentations = data['segmentation_all_unoccluded']
+ timestamps = data['timestamps_final_unoccluded']
+
+ image = image_read.copy()
+ indices = np.random.randint(len(tracks),size=30)
+ prime_numbers = primes(1000)
+ unoccluded_index_all = get_unoccluded(indices, tracks)
+
+ mask_stacked = image*0
+ mask_stacked_all =[]
+ count = 0
+ time = datetime.datetime.now()
+
+ for l in indices:
+ try:
+ image_crop = get_image(timestamps[l], path)
+ except:
+ continue
+ try:
+ bb_left, bb_top, bb_width, bb_height, confidence = tracks[l]
+ except:
+ bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[l]
+ mask = get_mask(segmentations[l], image)
+
+ image[mask > 0] = image_crop[mask > 0]
+ mask[mask > 0] = 1
+ for count, mask_inc in enumerate(mask_stacked_all):
+ mask_stacked_all[count][cv2.bitwise_and(mask, mask_inc) > 0] = 2
+ mask_stacked_all.append(mask)
+ mask_stacked += mask
+ count = count+1
+
+ cv2.imwrite(save_path + '/images/'+str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg', image[:, :, ::-1])
+ cv2.imwrite(save_path + '/Segmentation/'+str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg', mask_stacked[:, :, ::-1]*30)
+ np.savez_compressed(save_path+'/Segmentation/'+str(time).replace(' ','T').replace(':','-').split('+')[0], mask=mask_stacked_all)
+
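+# Expected on-disk layout for CWALT_Generation below (a sketch inferred from the
+# loaders in this file, not enforced anywhere): data/<camera_name>/<camera_name>.json.npz
+# holds the unoccluded tracks/segmentations/timestamps, data/<camera_name>/week0..week4/
+# hold the time-stamped frames, and data/<camera_name>/T18-median_image.jpg is the
+# median background image that composites are pasted onto.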
+def CWALT_Generation(camera_name):
+ save_path_train = 'data/cwalt_train'
+ save_path_test = 'data/cwalt_test'
+
+ json_file_path = 'data/{}/{}.json'.format(camera_name,camera_name) # iii1/iii1_7_test.json' # './data.json'
+ path = 'data/' + camera_name
+
+ data = np.load(json_file_path + '.npz', allow_pickle=True)
+
+    ## split data into train and test sets
+
+ data_train=dict()
+ data_test=dict()
+
+ split_index = int(len(data['timestamps_final_unoccluded'])*0.8)
+
+ data_train['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][0:split_index]
+ data_train['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][0:split_index]
+ data_train['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][0:split_index]
+
+ data_test['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][split_index:]
+ data_test['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][split_index:]
+ data_test['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][split_index:]
+
+ image_read = np.array(Image.open(path + '/T18-median_image.jpg'))
+ image_read = cv2.resize(image_read, (int(image_read.shape[1]/2), int(image_read.shape[0]/2)))
+
+ try:
+ os.mkdir(save_path_train)
+ except:
+ print(save_path_train)
+
+ try:
+ os.mkdir(save_path_train + '/images')
+ os.mkdir(save_path_train + '/Segmentation')
+ except:
+ print(save_path_train+ '/images')
+
+ try:
+ os.mkdir(save_path_test)
+ except:
+ print(save_path_test)
+
+ try:
+ os.mkdir(save_path_test + '/images')
+ os.mkdir(save_path_test + '/Segmentation')
+ except:
+ print(save_path_test+ '/images')
+
+ for loop in tqdm(range(3000), desc="Generating training CWALT Images "):
+ save_image(image_read, save_path_train, data_train, path)
+
+ for loop in tqdm(range(300), desc="Generating testing CWALT Images "):
+ save_image(image_read, save_path_test, data_test, path)
+
diff --git a/cwalt/Clip_WALT_Generate.py b/cwalt/Clip_WALT_Generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..09540a37a3a94600ac01a585f58b09270d070da7
--- /dev/null
+++ b/cwalt/Clip_WALT_Generate.py
@@ -0,0 +1,284 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:15:11 2022
+
+@author: dinesh
+"""
+
+from collections import OrderedDict
+from matplotlib import pyplot as plt
+from .utils import *
+import scipy.interpolate
+
+from scipy import interpolate
+from .clustering_utils import *
+import glob
+import cv2
+from PIL import Image
+
+
+import json
+import cv2
+
+import numpy as np
+from tqdm import tqdm
+
+
+def ignore_indexes(tracks_all, labels_all):
+ # get repeating bounding boxes
+ get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
+ ignore_ind = []
+ for index, track in enumerate(tracks_all):
+ print('in ignore', index, len(tracks_all))
+ if index in ignore_ind:
+ continue
+
+ if labels_all[index] < 1 or labels_all[index] > 3:
+ ignore_ind.extend([index])
+
+ ind = get_indexes(track, tracks_all)
+ if len(ind) > 30:
+ ignore_ind.extend(ind)
+
+ return ignore_ind
+
+def repeated_indexes_old(tracks_all,ignore_ind, unoccluded_indexes=None):
+ # get repeating bounding boxes
+ get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
+ repeat_ind = []
+ repeat_inds =[]
+ if unoccluded_indexes == None:
+ for index, track in enumerate(tracks_all):
+ if index in repeat_ind or index in ignore_ind:
+ continue
+ ind = get_indexes(track, tracks_all)
+ if len(ind) > 20:
+ repeat_ind.extend(ind)
+ repeat_inds.append([ind,track])
+ else:
+ for index in unoccluded_indexes:
+ if index in repeat_ind or index in ignore_ind:
+ continue
+ ind = get_indexes(tracks_all[index], tracks_all)
+ if len(ind) > 3:
+ repeat_ind.extend(ind)
+ repeat_inds.append([ind,tracks_all[index]])
+ return repeat_inds
+
+def get_unoccluded_instances(timestamps_final, tracks_all, ignore_ind=[], threshold = 0.01):
+ get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x==y]
+ unoccluded_indexes = []
+ time_checked = []
+ stationary_obj = []
+ count =0
+
+    for time in tqdm(np.unique(timestamps_final), desc="Detecting unoccluded objects in images "):
+ count += 1
+ if [time.year,time.month, time.day, time.hour, time.minute, time.second, time.microsecond] in time_checked:
+ analyze_bb = []
+ for ind in unoccluded_indexes_time:
+ for ind_compare in same_time_instances:
+ iou = bb_intersection_over_union(tracks_all[ind], tracks_all[ind_compare])
+ if iou < 0.5 and iou > 0:
+ analyze_bb.extend([ind_compare])
+ if iou > 0.99:
+ stationary_obj.extend([str(ind_compare)+'+'+str(ind)])
+
+ for ind in analyze_bb:
+ occ = False
+ for ind_compare in same_time_instances:
+ if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind:
+ occ = True
+ break
+ if occ == False:
+ unoccluded_indexes.extend([ind])
+ continue
+
+ same_time_instances = get_indexes(time,timestamps_final)
+ unoccluded_indexes_time = []
+
+ for ind in same_time_instances:
+ if tracks_all[ind][4] < 0.9 or ind in ignore_ind:# or ind != 1859:
+ continue
+ occ = False
+ for ind_compare in same_time_instances:
+ if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind and tracks_all[ind_compare][4] < 0.5:
+ occ = True
+ break
+ if occ==False:
+ unoccluded_indexes.extend([ind])
+ unoccluded_indexes_time.extend([ind])
+ time_checked.append([time.year,time.month, time.day, time.hour, time.minute, time.second, time.microsecond])
+ return unoccluded_indexes,stationary_obj
+
+def visualize_unoccluded_detection(timestamps_final,tracks_all,segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name, ignore_ind=[]):
+ tracks_final = []
+ tracks_final.append([])
+ try:
+ os.mkdir(cwalt_data_path + '/' + camera_name+'_unoccluded_car_detection/')
+ except:
+ print('Unoccluded debugging exists')
+
+    for time in tqdm(np.unique(timestamps_final), desc="Visualizing unoccluded objects in images "):
+ get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x==y]
+ ind = get_indexes(time, timestamps_final)
+ image_unocc = False
+ for index in ind:
+ if index not in unoccluded_indexes:
+ continue
+ else:
+ image_unocc = True
+ break
+ if image_unocc == False:
+ continue
+
+ for week_loop in range(5):
+ try:
+ image = np.array(Image.open(cwalt_data_path+'/week' +str(week_loop)+'/'+ str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg'))
+ break
+ except:
+ continue
+
+ try:
+ mask = image*0
+ except:
+ print('image not found for ' + str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg' )
+ continue
+ image_original = image.copy()
+
+ for index in ind:
+ track = tracks_all[index]
+
+ if index in ignore_ind:
+ continue
+ if index not in unoccluded_indexes:
+ continue
+ try:
+ bb_left, bb_top, bb_width, bb_height, confidence, id = track
+ except:
+ bb_left, bb_top, bb_width, bb_height, confidence = track
+
+ if confidence > 0.6:
+ mask = poly_seg(image, segmentation_all[index])
+ cv2.imwrite(cwalt_data_path + '/' + camera_name+'_unoccluded_car_detection/' + str(index)+'.png', mask[:, :, ::-1])
+
+def repeated_indexes(tracks_all,ignore_ind, repeat_count = 10, unoccluded_indexes=None):
+ get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
+ repeat_ind = []
+ repeat_inds =[]
+ if unoccluded_indexes == None:
+ for index, track in enumerate(tracks_all):
+ if index in repeat_ind or index in ignore_ind:
+ continue
+
+ ind = get_indexes(track, tracks_all)
+ if len(ind) > repeat_count:
+ repeat_ind.extend(ind)
+ repeat_inds.append([ind,track])
+ else:
+ for index in unoccluded_indexes:
+ if index in repeat_ind or index in ignore_ind:
+ continue
+ ind = get_indexes(tracks_all[index], tracks_all)
+ if len(ind) > repeat_count:
+ repeat_ind.extend(ind)
+ repeat_inds.append([ind,tracks_all[index]])
+
+
+ return repeat_inds
+
+def poly_seg(image, segm):
+ poly = np.array(segm).reshape((int(len(segm)/2), 2))
+ overlay = image.copy()
+ alpha = 0.5
+ cv2.fillPoly(overlay, [poly], color=(255, 255, 0))
+ cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
+ return image
+
+def visualize_unoccuded_clusters(repeat_inds, tracks, segmentation_all, timestamps_final, cwalt_data_path):
+ for index_, repeat_ind in enumerate(repeat_inds):
+ image = np.array(Image.open(cwalt_data_path+'/'+'T18-median_image.jpg'))
+ try:
+ os.mkdir(cwalt_data_path+ '/Cwalt_database/')
+ except:
+ print('folder exists')
+ try:
+ os.mkdir(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/')
+ except:
+ print(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/')
+
+ for i in repeat_ind[0]:
+ try:
+ bb_left, bb_top, bb_width, bb_height, confidence = tracks[i]#bbox
+ except:
+ bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[i]#bbox
+
+ cv2.rectangle(image,(int(bb_left), int(bb_top)),(int(bb_left+bb_width), int(bb_top+bb_height)),(0, 0, 255), 2)
+ time = timestamps_final[i]
+ for week_loop in range(5):
+ try:
+ image1 = np.array(Image.open(cwalt_data_path+'/week' +str(week_loop)+'/'+ str(time).replace(' ','T').replace(':','-').split('+')[0] + '.jpg'))
+ break
+ except:
+ continue
+
+ crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
+ cv2.imwrite(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/o_' + str(i) +'.jpg', crop[:, :, ::-1])
+ image1 = poly_seg(image1,segmentation_all[i])
+ crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
+ cv2.imwrite(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'/' + str(i)+'.jpg', crop[:, :, ::-1])
+ if index_ > 100:
+ break
+
+ cv2.imwrite(cwalt_data_path+ '/Cwalt_database/' + str(index_) +'.jpg', image[:, :, ::-1])
+
+def Get_unoccluded_objects(camera_name, debug = False, scale=True):
+ cwalt_data_path = 'data/' + camera_name
+ data_folder = cwalt_data_path
+ json_file_path = cwalt_data_path + '/' + camera_name + '.json'
+
+ with open(json_file_path, 'r') as j:
+ annotations = json.loads(j.read())
+
+ tracks_all = [parse_bbox(anno['bbox']) for anno in annotations]
+ segmentation_all = [parse_bbox(anno['segmentation']) for anno in annotations]
+ labels_all = [anno['label_id'] for anno in annotations]
+ timestamps_final = [parse(anno['time']) for anno in annotations]
+
+ if scale ==True:
+ scale_factor = 2
+ tracks_all_numpy = np.array(tracks_all)
+ tracks_all_numpy[:,:4] = np.array(tracks_all)[:,:4]/scale_factor
+ tracks_all = tracks_all_numpy.tolist()
+
+ segmentation_all_scaled = []
+ for list_loop in segmentation_all:
+ segmentation_all_scaled.append((np.floor_divide(np.array(list_loop),scale_factor)).tolist())
+ segmentation_all = segmentation_all_scaled
+
+ if debug == True:
+ timestamps_final = timestamps_final[:1000]
+ labels_all = labels_all[:1000]
+ segmentation_all = segmentation_all[:1000]
+ tracks_all = tracks_all[:1000]
+
+ unoccluded_indexes, stationary = get_unoccluded_instances(timestamps_final, tracks_all, threshold = 0.05)
+ if debug == True:
+ visualize_unoccluded_detection(timestamps_final, tracks_all, segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name)
+
+ tracks_all_unoccluded = [tracks_all[i] for i in unoccluded_indexes]
+ segmentation_all_unoccluded = [segmentation_all[i] for i in unoccluded_indexes]
+ labels_all_unoccluded = [labels_all[i] for i in unoccluded_indexes]
+ timestamps_final_unoccluded = [timestamps_final[i] for i in unoccluded_indexes]
+ np.savez(json_file_path,tracks_all_unoccluded=tracks_all_unoccluded, segmentation_all_unoccluded=segmentation_all_unoccluded, labels_all_unoccluded=labels_all_unoccluded, timestamps_final_unoccluded=timestamps_final_unoccluded )
+
+ if debug == True:
+ repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded,[], repeat_count=1)
+ visualize_unoccuded_clusters(repeat_inds_clusters, tracks_all_unoccluded, segmentation_all_unoccluded, timestamps_final_unoccluded, cwalt_data_path)
+ else:
+ repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded,[], repeat_count=10)
+
+ np.savez(json_file_path + '_clubbed', repeat_inds=repeat_inds_clusters)
+ np.savez(json_file_path + '_stationary', stationary=stationary)
+
diff --git a/cwalt/Download_Detections.py b/cwalt/Download_Detections.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5d368498ef506abd18169b7e128759c929ad565
--- /dev/null
+++ b/cwalt/Download_Detections.py
@@ -0,0 +1,28 @@
+import json
+from psycopg2.extras import RealDictCursor
+#import cv2
+import psycopg2
+import cv2
+
+
+CONNECTION = "postgres://postgres:"
+
+conn = psycopg2.connect(CONNECTION)
+cursor = conn.cursor(cursor_factory=RealDictCursor)
+
+
+def get_sample():
+ camera_name, camera_id = 'cam2', 4
+
+ print('Executing SQL command')
+
+ cursor.execute("SELECT * FROM annotations WHERE camera_id = {} and time >='2021-05-01 00:00:00' and time <='2021-05-07 23:59:50' and label_id in (1,2)".format(camera_id))
+
+ print('Dumping to json')
+ annotations = json.dumps(cursor.fetchall(), indent=2, default=str)
+ wjdata = json.loads(annotations)
+ with open('{}_{}_test.json'.format(camera_name, camera_id), 'w') as f:
+ json.dump(wjdata, f)
+ print('Done dumping to json')
+
+get_sample()
diff --git a/cwalt/clustering_utils.py b/cwalt/clustering_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..7463bfce84ae1c1089d9cf2a0e97de8e7397ce7a
--- /dev/null
+++ b/cwalt/clustering_utils.py
@@ -0,0 +1,132 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:18:20 2022
+
+@author: dinesh
+"""
+
+# 0 - Import related libraries
+
+import urllib
+import zipfile
+import os
+import scipy.io
+import math
+import numpy as np
+import matplotlib.pyplot as plt
+import seaborn as sns
+
+from scipy.spatial.distance import directed_hausdorff
+from sklearn.cluster import DBSCAN
+from sklearn.metrics.pairwise import pairwise_distances
+import scipy.spatial.distance
+
+from .kmedoid import kMedoids # kMedoids code is adapted from https://github.com/letiantian/kmedoids
+
+# Some visualization stuff, not so important
+# sns.set()
+plt.rcParams['figure.figsize'] = (12, 12)
+
+# Utility Functions
+
+color_lst = plt.rcParams['axes.prop_cycle'].by_key()['color']
+color_lst.extend(['firebrick', 'olive', 'indigo', 'khaki', 'teal', 'saddlebrown',
+ 'skyblue', 'coral', 'darkorange', 'lime', 'darkorchid', 'dimgray'])
+
+
+def plot_cluster(image, traj_lst, cluster_lst):
+ '''
+ Plots given trajectories with a color that is specific for every trajectory's own cluster index.
+ Outlier trajectories which are specified with -1 in `cluster_lst` are plotted dashed with black color
+ '''
+ cluster_count = np.max(cluster_lst) + 1
+
+ for traj, cluster in zip(traj_lst, cluster_lst):
+
+ # if cluster == -1:
+ # # Means it it a noisy trajectory, paint it black
+ # plt.plot(traj[:, 0], traj[:, 1], c='k', linestyle='dashed')
+ #
+ # else:
+ plt.plot(traj[:, 0], traj[:, 1], c=color_lst[cluster % len(color_lst)])
+
+ plt.imshow(image)
+ # plt.show()
+ plt.axis('off')
+ plt.savefig('trajectory.png', bbox_inches='tight')
+ plt.show()
+
+
+# 3 - Distance matrix
+
+def hausdorff( u, v):
+ d = max(directed_hausdorff(u, v)[0], directed_hausdorff(v, u)[0])
+ return d
+
+
+def build_distance_matrix(traj_lst):
+ # 2 - Trajectory segmentation
+
+ print('Running trajectory segmentation...')
+ degree_threshold = 5
+
+ for traj_index, traj in enumerate(traj_lst):
+
+ hold_index_lst = []
+ previous_azimuth = 1000
+
+ for point_index, point in enumerate(traj[:-1]):
+ next_point = traj[point_index + 1]
+ diff_vector = next_point - point
+ azimuth = (math.degrees(math.atan2(*diff_vector)) + 360) % 360
+
+ if abs(azimuth - previous_azimuth) > degree_threshold:
+ hold_index_lst.append(point_index)
+ previous_azimuth = azimuth
+ hold_index_lst.append(traj.shape[0] - 1) # Last point of trajectory is always added
+
+ traj_lst[traj_index] = traj[hold_index_lst, :]
+
+ print('Building distance matrix...')
+ traj_count = len(traj_lst)
+ D = np.zeros((traj_count, traj_count))
+
+ # This may take a while
+ for i in range(traj_count):
+ if i % 20 == 0:
+ print(i)
+ for j in range(i + 1, traj_count):
+ distance = hausdorff(traj_lst[i], traj_lst[j])
+ D[i, j] = distance
+ D[j, i] = distance
+
+ return D
+
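+# Each entry of traj_lst is assumed to be an (N, 2) numpy array of (x, y) points;
+# build_distance_matrix simplifies every trajectory in place (dropping points whose
+# heading change is below degree_threshold) and returns the symmetric Hausdorff
+# distance matrix D consumed by the clustering helpers below.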
+
+def run_kmedoids(image, traj_lst, D):
+ # 4 - Different clustering methods
+
+ # 4.1 - kmedoids
+
+ traj_count = len(traj_lst)
+
+ k = 3 # The number of clusters
+ medoid_center_lst, cluster2index_lst = kMedoids(D, k)
+
+ cluster_lst = np.empty((traj_count,), dtype=int)
+
+ for cluster in cluster2index_lst:
+ cluster_lst[cluster2index_lst[cluster]] = cluster
+
+ plot_cluster(image, traj_lst, cluster_lst)
+
+
+def run_dbscan(image, traj_lst, D):
+ mdl = DBSCAN(eps=400, min_samples=10)
+ cluster_lst = mdl.fit_predict(D)
+
+ plot_cluster(image, traj_lst, cluster_lst)
+
+
+
diff --git a/cwalt/kmedoid.py b/cwalt/kmedoid.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a04839cf1fd9e8d1bf56872f1c67d0bd7005cb9
--- /dev/null
+++ b/cwalt/kmedoid.py
@@ -0,0 +1,55 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:18:56 2022
+
+@author: dinesh
+"""
+
+import numpy as np
+import math
+
+def kMedoids(D, k, tmax=100):
+ # determine dimensions of distance matrix D
+ m, n = D.shape
+
+ np.fill_diagonal(D, math.inf)
+
+ if k > n:
+ raise Exception('too many medoids')
+ # randomly initialize an array of k medoid indices
+ M = np.arange(n)
+ np.random.shuffle(M)
+ M = np.sort(M[:k])
+
+ # create a copy of the array of medoid indices
+ Mnew = np.copy(M)
+
+ # initialize a dictionary to represent clusters
+ C = {}
+ for t in range(tmax):
+ # determine clusters, i. e. arrays of data indices
+ J = np.argmin(D[:,M], axis=1)
+
+ for kappa in range(k):
+ C[kappa] = np.where(J==kappa)[0]
+ # update cluster medoids
+ for kappa in range(k):
+ J = np.mean(D[np.ix_(C[kappa],C[kappa])],axis=1)
+ j = np.argmin(J)
+ Mnew[kappa] = C[kappa][j]
+    Mnew = np.sort(Mnew)
+ # check for convergence
+ if np.array_equal(M, Mnew):
+ break
+ M = np.copy(Mnew)
+ else:
+ # final update of cluster memberships
+ J = np.argmin(D[:,M], axis=1)
+ for kappa in range(k):
+ C[kappa] = np.where(J==kappa)[0]
+
+ np.fill_diagonal(D, 0)
+
+ # return results
+ return M, C
\ No newline at end of file
diff --git a/cwalt/utils.py b/cwalt/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..57f8e05a01cb4895dd95a4175f96a35974ee3ea3
--- /dev/null
+++ b/cwalt/utils.py
@@ -0,0 +1,168 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Fri May 20 15:16:56 2022
+
+@author: dinesh
+"""
+
+import json
+import cv2
+from PIL import Image
+import numpy as np
+from dateutil.parser import parse
+
+def bb_intersection_over_union(box1, box2):
+ #print(box1, box2)
+ boxA = box1.copy()
+ boxB = box2.copy()
+ boxA[2] = boxA[0]+boxA[2]
+ boxA[3] = boxA[1]+boxA[3]
+ boxB[2] = boxB[0]+boxB[2]
+ boxB[3] = boxB[1]+boxB[3]
+ # determine the (x, y)-coordinates of the intersection rectangle
+ xA = max(boxA[0], boxB[0])
+ yA = max(boxA[1], boxB[1])
+ xB = min(boxA[2], boxB[2])
+ yB = min(boxA[3], boxB[3])
+
+ # compute the area of intersection rectangle
+ interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
+
+ if interArea == 0:
+ return 0
+ # compute the area of both the prediction and ground-truth
+ # rectangles
+ boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
+ boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
+
+ # compute the intersection over union by taking the intersection
+ # area and dividing it by the sum of prediction + ground-truth
+    # areas - the intersection area
+ iou = interArea / float(boxAArea + boxBArea - interArea)
+ return iou
+
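+# Both IoU helpers in this module expect boxes as [x_left, y_top, width, height],
+# optionally followed by a confidence score. A rough usage sketch (values are
+# illustrative only):
+#   bb_intersection_over_union([0, 0, 10, 10], [5, 0, 10, 10])  # ~0.33
+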
+def bb_intersection_over_union_unoccluded(box1, box2, threshold=0.01):
+ #print(box1, box2)
+ boxA = box1.copy()
+ boxB = box2.copy()
+ boxA[2] = boxA[0]+boxA[2]
+ boxA[3] = boxA[1]+boxA[3]
+ boxB[2] = boxB[0]+boxB[2]
+ boxB[3] = boxB[1]+boxB[3]
+ # determine the (x, y)-coordinates of the intersection rectangle
+ xA = max(boxA[0], boxB[0])
+ yA = max(boxA[1], boxB[1])
+ xB = min(boxA[2], boxB[2])
+ yB = min(boxA[3], boxB[3])
+
+ # compute the area of intersection rectangle
+ interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
+
+ if interArea == 0:
+ return 0
+ # compute the area of both the prediction and ground-truth
+ # rectangles
+ boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
+ boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
+
+ # compute the intersection over union by taking the intersection
+ # area and dividing it by the sum of prediction + ground-truth
+    # areas - the intersection area
+ iou = interArea / float(boxAArea + boxBArea - interArea)
+
+ #print(iou)
+ # return the intersection over union value
+ occlusion = False
+ if iou > threshold and iou < 1:
+ #print(boxA[3], boxB[3], boxB[1])
+ if boxA[3] < boxB[3]:# and boxA[3] > boxB[1]:
+ if boxB[2] > boxA[0]:# and boxB[2] < boxA[2]:
+ #print('first', (boxB[2] - boxA[0])/(boxA[2] - boxA[0]))
+ if (min(boxB[2],boxA[2]) - boxA[0])/(boxA[2] - boxA[0]) > threshold:
+ occlusion = True
+
+ if boxB[0] < boxA[2]: # boxB[0] > boxA[0] and
+ #print('second', (boxA[2] - boxB[0])/(boxA[2] - boxA[0]))
+ if (boxA[2] - max(boxB[0],boxA[0]))/(boxA[2] - boxA[0]) > threshold:
+ occlusion = True
+ if occlusion == False:
+ iou = iou*0
+ #iou = 0.9 #iou*0
+ #print(box1, box2, iou, occlusion)
+ return iou
+def draw_tracks(image, tracks):
+ """
+ Draw on input image.
+
+ Args:
+ image (numpy.ndarray): image
+ tracks (list): list of tracks to be drawn on the image.
+
+ Returns:
+ numpy.ndarray: image with the track-ids drawn on it.
+ """
+
+ for trk in tracks:
+
+ trk_id = trk[1]
+ xmin = trk[2]
+ ymin = trk[3]
+ width = trk[4]
+ height = trk[5]
+
+ xcentroid, ycentroid = int(xmin + 0.5*width), int(ymin + 0.5*height)
+
+ text = "ID {}".format(trk_id)
+
+ cv2.putText(image, text, (xcentroid - 10, ycentroid - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
+ cv2.circle(image, (xcentroid, ycentroid), 4, (0, 255, 0), -1)
+
+ return image
+
+
+def draw_bboxes(image, tracks):
+ """
+ Draw the bounding boxes about detected objects in the image.
+
+    Args:
+        image (numpy.ndarray): Image or video frame.
+        tracks (list): list of tracks; each track carries the bounding box as
+            (xmin, ymin, width, height) at indices 2-5.
+
+ Returns:
+ numpy.ndarray: image with the bounding boxes drawn on it.
+ """
+
+ for trk in tracks:
+ xmin = int(trk[2])
+ ymin = int(trk[3])
+ width = int(trk[4])
+ height = int(trk[5])
+ clr = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
+ cv2.rectangle(image, (xmin, ymin), (xmin + width, ymin + height), clr, 2)
+
+ return image
+
+
+def num(v):
+ number_as_float = float(v)
+ number_as_int = int(number_as_float)
+ return number_as_int if number_as_float == number_as_int else number_as_float
+
+
+def parse_bbox(bbox_str):
+ bbox_list = bbox_str.strip('{').strip('}').split(',')
+ bbox_list = [num(elem) for elem in bbox_list]
+ return bbox_list
+
+def parse_seg(bbox_str):
+ bbox_list = bbox_str.strip('{').strip('}').split(',')
+ bbox_list = [num(elem) for elem in bbox_list]
+ ret = bbox_list # []
+ # for i in range(0, len(bbox_list) - 1, 2):
+ # ret.append((bbox_list[i], bbox_list[i + 1]))
+ return ret
diff --git a/cwalt_generate.py b/cwalt_generate.py
new file mode 100644
index 0000000000000000000000000000000000000000..18e16c8912d8fe879bfe82ad47d4c04abe44766e
--- /dev/null
+++ b/cwalt_generate.py
@@ -0,0 +1,14 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+"""
+Created on Sat Jun 4 16:55:58 2022
+
+@author: dinesh
+"""
+from cwalt.CWALT import CWALT_Generation
+from cwalt.Clip_WALT_Generate import Get_unoccluded_objects
+
+if __name__ == '__main__':
+ camera_name = 'cam2'
+ Get_unoccluded_objects(camera_name)
+ CWALT_Generation(camera_name)
diff --git a/docker/Dockerfile b/docker/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..fcbef9057621342d69644598f2fb865aee80001f
--- /dev/null
+++ b/docker/Dockerfile
@@ -0,0 +1,52 @@
+ARG PYTORCH="1.9.0"
+ARG CUDA="11.1"
+ARG CUDNN="8"
+
+FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
+
+ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
+ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
+ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
+RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
+RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev \
+ && apt-get clean \
+ && rm -rf /var/lib/apt/lists/*
+
+# Install MMCV
+#RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
+# -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
+RUN pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
+# Install MMDetection
+RUN conda clean --all
+RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
+WORKDIR /mmdetection
+ENV FORCE_CUDA="1"
+RUN cd /mmdetection && git checkout 7bd39044f35aec4b90dd797b965777541a8678ff
+RUN pip install -r requirements/build.txt
+RUN pip install --no-cache-dir -e .
+RUN apt-get update
+RUN apt-get install -y vim
+RUN pip uninstall -y pycocotools
+RUN pip install mmpycocotools timm scikit-image imagesize
+
+
+# make sure we don't overwrite some existing directory called "apex"
+WORKDIR /tmp/unique_for_apex
+# uninstall Apex if present, twice to make absolutely sure :)
+RUN pip uninstall -y apex || :
+RUN pip uninstall -y apex || :
+# SHA is something the user can touch to force recreation of this Docker layer,
+# and therefore force cloning of the latest version of Apex
+RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
+WORKDIR /tmp/unique_for_apex/apex
+RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
+RUN pip install seaborn scikit-learn imantics gradio
+WORKDIR /code
+ENTRYPOINT ["python", "app.py"]
+
+#RUN git clone https://github.com/NVIDIA/apex
+#RUN cd apex
+#RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
+#RUN pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
+
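+# A rough local usage sketch (the image tag and bind mount below are assumptions,
+# not part of this repo). The source tree is not copied into the image, so mount
+# the repository at /code when running:
+#   docker build -t walt -f docker/Dockerfile .
+#   docker run --gpus all -v $(pwd):/code walt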
diff --git a/github_vis/cwalt.gif b/github_vis/cwalt.gif
new file mode 100644
index 0000000000000000000000000000000000000000..28ee6c5b705630b683ed2180c64bc8ede050731e
Binary files /dev/null and b/github_vis/cwalt.gif differ
diff --git a/github_vis/vis_cars.gif b/github_vis/vis_cars.gif
new file mode 100644
index 0000000000000000000000000000000000000000..c812b77f346acc25b484f0760e36e7fcc16d7cfc
Binary files /dev/null and b/github_vis/vis_cars.gif differ
diff --git a/github_vis/vis_people.gif b/github_vis/vis_people.gif
new file mode 100644
index 0000000000000000000000000000000000000000..2fa6b620a42badab4871b140b0f1f7da3a7ba305
Binary files /dev/null and b/github_vis/vis_people.gif differ
diff --git a/infer.py b/infer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee71873a0955453fd137947678a0e8b4a1423b08
--- /dev/null
+++ b/infer.py
@@ -0,0 +1,118 @@
+from argparse import ArgumentParser
+
+from mmdet.apis import inference_detector, init_detector, show_result_pyplot
+from mmdet.core.mask.utils import encode_mask_results
+import numpy as np
+import mmcv
+import torch
+from imantics import Polygons, Mask
+import json
+import os
+import cv2, glob
+
+class detections():
+ def __init__(self, cfg_path, device, model_path = 'data/models/walt_vehicle.pth', threshold=0.85):
+ self.model = init_detector(cfg_path, model_path, device=device)
+ self.all_preds = []
+ self.all_scores = []
+ self.index = []
+ self.score_thr = threshold
+ self.result = []
+ self.record_dict = {'model': cfg_path,'results': []}
+ self.detect_count = []
+
+
+ def run_on_image(self, image):
+ self.result = inference_detector(self.model, image)
+ image_labelled = self.model.show_result(image, self.result, score_thr=self.score_thr)
+ return image_labelled
+
+ def process_output(self, count):
+ result = self.result
+ infer_result = {'url': count,
+ 'boxes': [],
+ 'scores': [],
+ 'keypoints': [],
+ 'segmentation': [],
+ 'label_ids': [],
+ 'track': [],
+ 'labels': []}
+
+ if isinstance(result, tuple):
+ bbox_result, segm_result = result
+ #segm_result = encode_mask_results(segm_result)
+ if isinstance(segm_result, tuple):
+ segm_result = segm_result[0] # ms rcnn
+ bboxes = np.vstack(bbox_result)
+ labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)]
+
+ labels = np.concatenate(labels)
+ segms = None
+ if segm_result is not None and len(labels) > 0: # non empty
+ segms = mmcv.concat_list(segm_result)
+ if isinstance(segms[0], torch.Tensor):
+ segms = torch.stack(segms, dim=0).detach().cpu().numpy()
+ else:
+ segms = np.stack(segms, axis=0)
+
+ for i, (bbox, label, segm) in enumerate(zip(bboxes, labels, segms)):
+ if bbox[-1].item() <0.3:
+ continue
+ box = [bbox[0].item(), bbox[1].item(), bbox[2].item(), bbox[3].item()]
+ polygons = Mask(segm).polygons()
+
+ infer_result['boxes'].append(box)
+ infer_result['segmentation'].append(polygons.segmentation)
+ infer_result['scores'].append(bbox[-1].item())
+ infer_result['labels'].append(self.model.CLASSES[label])
+ infer_result['label_ids'].append(label)
+ self.record_dict['results'].append(infer_result)
+ self.detect_count = labels
+
+ def write_json(self, filename):
+ with open(filename + '.json', 'w') as f:
+ json.dump(self.record_dict, f)
+
+
+def main():
+    if torch.cuda.is_available():
+        device = 'cuda:0'
+    else:
+        device = 'cpu'
+    detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
+    detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth')
+    filenames = sorted(glob.glob('demo/images/*'))
+    count = 0
+    for filename in filenames:
+        img = cv2.imread(filename)
+        try:
+            img = detect_people.run_on_image(img)
+            img = detect.run_on_image(img)
+        except Exception:
+            # skip images that fail during inference
+            continue
+        count = count + 1
+
+        # mirror the demo/ directory layout under demo/results/
+        out_file = filename.replace('demo', 'demo/results/')
+        os.makedirs(os.path.dirname(out_file), exist_ok=True)
+        cv2.imwrite(out_file, img)
+        if count == 30000:
+            break
+        try:
+            detect.process_output(count)
+        except Exception:
+            continue
+if __name__ == "__main__":
+ main()
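+
+# Illustrative use of the `detections` wrapper (a hedged sketch, not part of
+# the batch pipeline above): the config/checkpoint paths are the ones shipped
+# with this repository; the input image path and output name are assumptions.
+#
+#   detector = detections('configs/walt/walt_vehicle.py', 'cpu',
+#                         model_path='data/models/walt_vehicle.pth',
+#                         threshold=0.85)
+#   img = cv2.imread('demo/images/img_1.jpg')
+#   vis = detector.run_on_image(img)          # image with drawn boxes/masks
+#   detector.process_output(0)                # accumulate per-image results
+#   detector.write_json('detections_seq0')    # writes detections_seq0.json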
diff --git a/mmcv_custom/__init__.py b/mmcv_custom/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e0e39b03e2a149c33c372472b2b814a872ec55c
--- /dev/null
+++ b/mmcv_custom/__init__.py
@@ -0,0 +1,5 @@
+# -*- coding: utf-8 -*-
+
+from .checkpoint import load_checkpoint
+
+__all__ = ['load_checkpoint']
diff --git a/mmcv_custom/checkpoint.py b/mmcv_custom/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..51322c1c3802f357481065a70dc5152469d80eb8
--- /dev/null
+++ b/mmcv_custom/checkpoint.py
@@ -0,0 +1,500 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import io
+import os
+import os.path as osp
+import pkgutil
+import time
+import warnings
+from collections import OrderedDict
+from importlib import import_module
+from tempfile import TemporaryDirectory
+
+import torch
+import torchvision
+from torch.optim import Optimizer
+from torch.utils import model_zoo
+from torch.nn import functional as F
+
+import mmcv
+from mmcv.fileio import FileClient
+from mmcv.fileio import load as load_file
+from mmcv.parallel import is_module_wrapper
+from mmcv.utils import mkdir_or_exist
+from mmcv.runner import get_dist_info
+
+ENV_MMCV_HOME = 'MMCV_HOME'
+ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
+DEFAULT_CACHE_DIR = '~/.cache'
+
+
+def _get_mmcv_home():
+ mmcv_home = os.path.expanduser(
+ os.getenv(
+ ENV_MMCV_HOME,
+ os.path.join(
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
+
+ mkdir_or_exist(mmcv_home)
+ return mmcv_home
+
+
+def load_state_dict(module, state_dict, strict=False, logger=None):
+ """Load state_dict to a module.
+
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
+ Default value for ``strict`` is set to ``False`` and the message for
+ param mismatch will be shown even if strict is False.
+
+ Args:
+ module (Module): Module that receives the state_dict.
+ state_dict (OrderedDict): Weights.
+ strict (bool): whether to strictly enforce that the keys
+ in :attr:`state_dict` match the keys returned by this module's
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
+ message. If not specified, print function will be used.
+ """
+ unexpected_keys = []
+ all_missing_keys = []
+ err_msg = []
+
+ metadata = getattr(state_dict, '_metadata', None)
+ state_dict = state_dict.copy()
+ if metadata is not None:
+ state_dict._metadata = metadata
+
+ # use _load_from_state_dict to enable checkpoint version control
+ def load(module, prefix=''):
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+ local_metadata = {} if metadata is None else metadata.get(
+ prefix[:-1], {})
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
+ all_missing_keys, unexpected_keys,
+ err_msg)
+ for name, child in module._modules.items():
+ if child is not None:
+ load(child, prefix + name + '.')
+
+ load(module)
+ load = None # break load->load reference cycle
+
+ # ignore "num_batches_tracked" of BN layers
+ missing_keys = [
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
+ ]
+
+ if unexpected_keys:
+ err_msg.append('unexpected key in source '
+ f'state_dict: {", ".join(unexpected_keys)}\n')
+ if missing_keys:
+ err_msg.append(
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
+
+ rank, _ = get_dist_info()
+ if len(err_msg) > 0 and rank == 0:
+ err_msg.insert(
+ 0, 'The model and loaded state dict do not match exactly\n')
+ err_msg = '\n'.join(err_msg)
+ if strict:
+ raise RuntimeError(err_msg)
+ elif logger is not None:
+ logger.warning(err_msg)
+ else:
+ print(err_msg)
+
+
+def load_url_dist(url, model_dir=None):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
+ return checkpoint
+
+
+def load_pavimodel_dist(model_path, map_location=None):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ try:
+ from pavi import modelcloud
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ if rank == 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ model = modelcloud.get(model_path)
+ with TemporaryDirectory() as tmp_dir:
+ downloaded_file = osp.join(tmp_dir, model.name)
+ model.download(downloaded_file)
+ checkpoint = torch.load(
+ downloaded_file, map_location=map_location)
+ return checkpoint
+
+
+def load_fileclient_dist(filename, backend, map_location):
+ """In distributed setting, this function only download checkpoint at local
+ rank 0."""
+ rank, world_size = get_dist_info()
+ rank = int(os.environ.get('LOCAL_RANK', rank))
+ allowed_backends = ['ceph']
+ if backend not in allowed_backends:
+ raise ValueError(f'Load from Backend {backend} is not supported.')
+ if rank == 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ if world_size > 1:
+ torch.distributed.barrier()
+ if rank > 0:
+ fileclient = FileClient(backend=backend)
+ buffer = io.BytesIO(fileclient.get(filename))
+ checkpoint = torch.load(buffer, map_location=map_location)
+ return checkpoint
+
+
+def get_torchvision_models():
+ model_urls = dict()
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
+ if ispkg:
+ continue
+ _zoo = import_module(f'torchvision.models.{name}')
+ if hasattr(_zoo, 'model_urls'):
+ _urls = getattr(_zoo, 'model_urls')
+ model_urls.update(_urls)
+ return model_urls
+
+
+def get_external_models():
+ mmcv_home = _get_mmcv_home()
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
+ default_urls = load_file(default_json_path)
+ assert isinstance(default_urls, dict)
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
+ if osp.exists(external_json_path):
+ external_urls = load_file(external_json_path)
+ assert isinstance(external_urls, dict)
+ default_urls.update(external_urls)
+
+ return default_urls
+
+
+def get_mmcls_models():
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
+ mmcls_urls = load_file(mmcls_json_path)
+
+ return mmcls_urls
+
+
+def get_deprecated_model_names():
+ deprecate_json_path = osp.join(mmcv.__path__[0],
+ 'model_zoo/deprecated.json')
+ deprecate_urls = load_file(deprecate_json_path)
+ assert isinstance(deprecate_urls, dict)
+
+ return deprecate_urls
+
+
+def _process_mmcls_checkpoint(checkpoint):
+ state_dict = checkpoint['state_dict']
+ new_state_dict = OrderedDict()
+ for k, v in state_dict.items():
+ if k.startswith('backbone.'):
+ new_state_dict[k[9:]] = v
+ new_checkpoint = dict(state_dict=new_state_dict)
+
+ return new_checkpoint
+
+
+def _load_checkpoint(filename, map_location=None):
+ """Load checkpoint from somewhere (modelzoo, file, url).
+
+ Args:
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
+
+ Returns:
+ dict | OrderedDict: The loaded checkpoint. It can be either an
+ OrderedDict storing model weights or a dict containing other
+ information, which depends on the checkpoint.
+ """
+ if filename.startswith('modelzoo://'):
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
+ 'use "torchvision://" instead')
+ model_urls = get_torchvision_models()
+ model_name = filename[11:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('torchvision://'):
+ model_urls = get_torchvision_models()
+ model_name = filename[14:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ elif filename.startswith('open-mmlab://'):
+ model_urls = get_external_models()
+ model_name = filename[13:]
+ deprecated_urls = get_deprecated_model_names()
+ if model_name in deprecated_urls:
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
+ f'of open-mmlab://{deprecated_urls[model_name]}')
+ model_name = deprecated_urls[model_name]
+ model_url = model_urls[model_name]
+ # check if is url
+ if model_url.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(model_url)
+ else:
+ filename = osp.join(_get_mmcv_home(), model_url)
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ elif filename.startswith('mmcls://'):
+ model_urls = get_mmcls_models()
+ model_name = filename[8:]
+ checkpoint = load_url_dist(model_urls[model_name])
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
+ elif filename.startswith(('http://', 'https://')):
+ checkpoint = load_url_dist(filename)
+ elif filename.startswith('pavi://'):
+ model_path = filename[7:]
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
+ elif filename.startswith('s3://'):
+ checkpoint = load_fileclient_dist(
+ filename, backend='ceph', map_location=map_location)
+ else:
+ if not osp.isfile(filename):
+ raise IOError(f'{filename} is not a checkpoint file')
+ checkpoint = torch.load(filename, map_location=map_location)
+ return checkpoint
+
+
+def load_checkpoint(model,
+ filename,
+ map_location='cpu',
+ strict=False,
+ logger=None):
+ """Load checkpoint from a file or URI.
+
+ Args:
+ model (Module): Module to load checkpoint.
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
+ details.
+ map_location (str): Same as :func:`torch.load`.
+ strict (bool): Whether to allow different params for the model and
+ checkpoint.
+ logger (:mod:`logging.Logger` or None): The logger for error message.
+
+ Returns:
+ dict or OrderedDict: The loaded checkpoint.
+ """
+ checkpoint = _load_checkpoint(filename, map_location)
+ # OrderedDict is a subclass of dict
+ if not isinstance(checkpoint, dict):
+ raise RuntimeError(
+ f'No state_dict found in checkpoint file {filename}')
+ # get state_dict from checkpoint
+ if 'state_dict' in checkpoint:
+ state_dict = checkpoint['state_dict']
+ elif 'model' in checkpoint:
+ state_dict = checkpoint['model']
+ else:
+ state_dict = checkpoint
+ # strip prefix of state_dict
+ if list(state_dict.keys())[0].startswith('module.'):
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
+
+ # for MoBY, load model of online branch
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
+
+ # reshape absolute position embedding
+ if state_dict.get('absolute_pos_embed') is not None:
+ absolute_pos_embed = state_dict['absolute_pos_embed']
+ N1, L, C1 = absolute_pos_embed.size()
+ N2, C2, H, W = model.absolute_pos_embed.size()
+ if N1 != N2 or C1 != C2 or L != H*W:
+ logger.warning("Error in loading absolute_pos_embed, pass")
+ else:
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
+
+ # interpolate position bias table if needed
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
+ for table_key in relative_position_bias_table_keys:
+ table_pretrained = state_dict[table_key]
+ table_current = model.state_dict()[table_key]
+ L1, nH1 = table_pretrained.size()
+ L2, nH2 = table_current.size()
+ if nH1 != nH2:
+ logger.warning(f"Error in loading {table_key}, pass")
+ else:
+ if L1 != L2:
+ S1 = int(L1 ** 0.5)
+ S2 = int(L2 ** 0.5)
+ table_pretrained_resized = F.interpolate(
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
+ size=(S2, S2), mode='bicubic')
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
+
+ # load state_dict
+ load_state_dict(model, state_dict, strict, logger)
+ return checkpoint
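+
+# Illustrative call patterns for load_checkpoint (a hedged sketch; `model` is a
+# placeholder nn.Module, and the zoo key and paths are examples that must exist
+# in the respective model zoo or on disk):
+#
+#   load_checkpoint(model, 'torchvision://resnet50')               # torchvision zoo
+#   load_checkpoint(model, 'https://example.com/checkpoint.pth')   # plain URL
+#   load_checkpoint(model, 'data/models/walt_vehicle.pth',         # local file
+#                   map_location='cpu')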
+
+
+def weights_to_cpu(state_dict):
+ """Copy a model state_dict to cpu.
+
+ Args:
+ state_dict (OrderedDict): Model weights on GPU.
+
+ Returns:
+        OrderedDict: Model weights on CPU.
+ """
+ state_dict_cpu = OrderedDict()
+ for key, val in state_dict.items():
+ state_dict_cpu[key] = val.cpu()
+ return state_dict_cpu
+
+
+def _save_to_state_dict(module, destination, prefix, keep_vars):
+ """Saves module state to `destination` dictionary.
+
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (dict): A dict where state will be stored.
+ prefix (str): The prefix for parameters and buffers used in this
+ module.
+ """
+ for name, param in module._parameters.items():
+ if param is not None:
+ destination[prefix + name] = param if keep_vars else param.detach()
+ for name, buf in module._buffers.items():
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
+ if buf is not None:
+ destination[prefix + name] = buf if keep_vars else buf.detach()
+
+
+def get_state_dict(module, destination=None, prefix='', keep_vars=False):
+ """Returns a dictionary containing a whole state of the module.
+
+ Both parameters and persistent buffers (e.g. running averages) are
+ included. Keys are corresponding parameter and buffer names.
+
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
+ recursively check parallel module in case that the model has a complicated
+ structure, e.g., nn.Module(nn.Module(DDP)).
+
+ Args:
+ module (nn.Module): The module to generate state_dict.
+ destination (OrderedDict): Returned dict for the state of the
+ module.
+ prefix (str): Prefix of the key.
+ keep_vars (bool): Whether to keep the variable property of the
+ parameters. Default: False.
+
+ Returns:
+ dict: A dictionary containing a whole state of the module.
+ """
+ # recursively check parallel module in case that the model has a
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
+ if is_module_wrapper(module):
+ module = module.module
+
+ # below is the same as torch.nn.Module.state_dict()
+ if destination is None:
+ destination = OrderedDict()
+ destination._metadata = OrderedDict()
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
+ version=module._version)
+ _save_to_state_dict(module, destination, prefix, keep_vars)
+ for name, child in module._modules.items():
+ if child is not None:
+ get_state_dict(
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
+ for hook in module._state_dict_hooks.values():
+ hook_result = hook(module, destination, prefix, local_metadata)
+ if hook_result is not None:
+ destination = hook_result
+ return destination
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
+ ``optimizer``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+
+ if filename.startswith('pavi://'):
+ try:
+ from pavi import modelcloud
+ from pavi.exception import NodeNotFoundError
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ mmcv.mkdir_or_exist(osp.dirname(filename))
+ # immediately flush buffer
+ with open(filename, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
diff --git a/mmcv_custom/runner/__init__.py b/mmcv_custom/runner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c701cb016abe470611830dc960999970738352bb
--- /dev/null
+++ b/mmcv_custom/runner/__init__.py
@@ -0,0 +1,8 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+from .checkpoint import save_checkpoint
+from .epoch_based_runner import EpochBasedRunnerAmp
+
+
+__all__ = [
+ 'EpochBasedRunnerAmp', 'save_checkpoint'
+]
diff --git a/mmcv_custom/runner/checkpoint.py b/mmcv_custom/runner/checkpoint.py
new file mode 100644
index 0000000000000000000000000000000000000000..b04167e0fc5f16bc33e793830ebb9c4ef15ef1ed
--- /dev/null
+++ b/mmcv_custom/runner/checkpoint.py
@@ -0,0 +1,85 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import os.path as osp
+import time
+from tempfile import TemporaryDirectory
+
+import torch
+from torch.optim import Optimizer
+
+import mmcv
+from mmcv.parallel import is_module_wrapper
+from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
+
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
+
+
+def save_checkpoint(model, filename, optimizer=None, meta=None):
+ """Save checkpoint to file.
+
+    The checkpoint will have 4 fields: ``meta``, ``state_dict``, ``optimizer``
+    and ``amp``. By default ``meta`` will contain version and time info.
+
+ Args:
+ model (Module): Module whose params are to be saved.
+ filename (str): Checkpoint filename.
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
+ meta (dict, optional): Metadata to be saved in checkpoint.
+ """
+ if meta is None:
+ meta = {}
+ elif not isinstance(meta, dict):
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
+
+ if is_module_wrapper(model):
+ model = model.module
+
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
+ # save class name to the meta
+ meta.update(CLASSES=model.CLASSES)
+
+ checkpoint = {
+ 'meta': meta,
+ 'state_dict': weights_to_cpu(get_state_dict(model))
+ }
+ # save optimizer state dict in the checkpoint
+ if isinstance(optimizer, Optimizer):
+ checkpoint['optimizer'] = optimizer.state_dict()
+ elif isinstance(optimizer, dict):
+ checkpoint['optimizer'] = {}
+ for name, optim in optimizer.items():
+ checkpoint['optimizer'][name] = optim.state_dict()
+
+ # save amp state dict in the checkpoint
+ checkpoint['amp'] = apex.amp.state_dict()
+
+ if filename.startswith('pavi://'):
+ try:
+ from pavi import modelcloud
+ from pavi.exception import NodeNotFoundError
+ except ImportError:
+ raise ImportError(
+ 'Please install pavi to load checkpoint from modelcloud.')
+ model_path = filename[7:]
+ root = modelcloud.Folder()
+ model_dir, model_name = osp.split(model_path)
+ try:
+ model = modelcloud.get(model_dir)
+ except NodeNotFoundError:
+ model = root.create_training_model(model_dir)
+ with TemporaryDirectory() as tmp_dir:
+ checkpoint_file = osp.join(tmp_dir, model_name)
+ with open(checkpoint_file, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
+ model.create_file(checkpoint_file, name=model_name)
+ else:
+ mmcv.mkdir_or_exist(osp.dirname(filename))
+ # immediately flush buffer
+ with open(filename, 'wb') as f:
+ torch.save(checkpoint, f)
+ f.flush()
diff --git a/mmcv_custom/runner/epoch_based_runner.py b/mmcv_custom/runner/epoch_based_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cdf3fa05639f7fde652090be9dbf78b48790744
--- /dev/null
+++ b/mmcv_custom/runner/epoch_based_runner.py
@@ -0,0 +1,104 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+import os.path as osp
+import platform
+import shutil
+
+import torch
+from torch.optim import Optimizer
+
+import mmcv
+from mmcv.runner import RUNNERS, EpochBasedRunner
+from .checkpoint import save_checkpoint
+
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
+
+
+@RUNNERS.register_module()
+class EpochBasedRunnerAmp(EpochBasedRunner):
+ """Epoch-based Runner with AMP support.
+
+ This runner train models epoch by epoch.
+ """
+
+ def save_checkpoint(self,
+ out_dir,
+ filename_tmpl='epoch_{}.pth',
+ save_optimizer=True,
+ meta=None,
+ create_symlink=True):
+ """Save the checkpoint.
+
+ Args:
+ out_dir (str): The directory that checkpoints are saved.
+ filename_tmpl (str, optional): The checkpoint filename template,
+ which contains a placeholder for the epoch number.
+ Defaults to 'epoch_{}.pth'.
+ save_optimizer (bool, optional): Whether to save the optimizer to
+ the checkpoint. Defaults to True.
+ meta (dict, optional): The meta information to be saved in the
+ checkpoint. Defaults to None.
+ create_symlink (bool, optional): Whether to create a symlink
+ "latest.pth" to point to the latest checkpoint.
+ Defaults to True.
+ """
+ if meta is None:
+ meta = dict(epoch=self.epoch + 1, iter=self.iter)
+ elif isinstance(meta, dict):
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
+ else:
+ raise TypeError(
+ f'meta should be a dict or None, but got {type(meta)}')
+ if self.meta is not None:
+ meta.update(self.meta)
+
+ filename = filename_tmpl.format(self.epoch + 1)
+ filepath = osp.join(out_dir, filename)
+ optimizer = self.optimizer if save_optimizer else None
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
+ # in some environments, `os.symlink` is not supported, you may need to
+ # set `create_symlink` to False
+ if create_symlink:
+ dst_file = osp.join(out_dir, 'latest.pth')
+ if platform.system() != 'Windows':
+ mmcv.symlink(filename, dst_file)
+ else:
+ shutil.copy(filepath, dst_file)
+
+ def resume(self,
+ checkpoint,
+ resume_optimizer=True,
+ map_location='default'):
+ if map_location == 'default':
+ if torch.cuda.is_available():
+ device_id = torch.cuda.current_device()
+ checkpoint = self.load_checkpoint(
+ checkpoint,
+ map_location=lambda storage, loc: storage.cuda(device_id))
+ else:
+ checkpoint = self.load_checkpoint(checkpoint)
+ else:
+ checkpoint = self.load_checkpoint(
+ checkpoint, map_location=map_location)
+
+ self._epoch = checkpoint['meta']['epoch']
+ self._iter = checkpoint['meta']['iter']
+ if 'optimizer' in checkpoint and resume_optimizer:
+ if isinstance(self.optimizer, Optimizer):
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
+ elif isinstance(self.optimizer, dict):
+ for k in self.optimizer.keys():
+ self.optimizer[k].load_state_dict(
+ checkpoint['optimizer'][k])
+ else:
+ raise TypeError(
+ 'Optimizer should be dict or torch.optim.Optimizer '
+ f'but got {type(self.optimizer)}')
+
+ if 'amp' in checkpoint:
+ apex.amp.load_state_dict(checkpoint['amp'])
+ self.logger.info('load amp state dict')
+
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
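+
+# Illustrative config fragment (a hedged sketch): thanks to the
+# @RUNNERS.register_module() decorator, a training config can select this
+# runner by name; the epoch count and resume path below are assumptions.
+#
+#   runner = dict(type='EpochBasedRunnerAmp', max_epochs=36)
+#   resume_from = 'work_dirs/walt/latest.pth'  # resume() also restores the 'amp' state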
diff --git a/mmdet/__init__.py b/mmdet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce2930f62a0091e06b37575b96db2ae51ca7908e
--- /dev/null
+++ b/mmdet/__init__.py
@@ -0,0 +1,28 @@
+import mmcv
+
+from .version import __version__, short_version
+
+
+def digit_version(version_str):
+ digit_version = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ digit_version.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ digit_version.append(int(patch_version[0]) - 1)
+ digit_version.append(int(patch_version[1]))
+ return digit_version
+
+
+mmcv_minimum_version = '1.2.4'
+mmcv_maximum_version = '1.4.0'
+mmcv_version = digit_version(mmcv.__version__)
+
+
+assert (mmcv_version >= digit_version(mmcv_minimum_version)
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
+ f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
+
+__all__ = ['__version__', 'short_version']
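+
+# Illustrative behaviour of digit_version (a hedged sketch of the check above,
+# not additional package code):
+#
+#   digit_version('1.4.0')    -> [1, 4, 0]
+#   digit_version('1.3.0rc1') -> [1, 3, -1, 1]   # rc builds sort before 1.3.0
+#
+# so the assertion accepts any installed mmcv-full with
+# 1.2.4 <= version <= 1.4.0.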
diff --git a/mmdet/apis/__init__.py b/mmdet/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d8035b74877fdeccaa41cbc10a9f1f9924eac85
--- /dev/null
+++ b/mmdet/apis/__init__.py
@@ -0,0 +1,10 @@
+from .inference import (async_inference_detector, inference_detector,
+ init_detector, show_result_pyplot)
+from .test import multi_gpu_test, single_gpu_test
+from .train import get_root_logger, set_random_seed, train_detector
+
+__all__ = [
+ 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
+ 'async_inference_detector', 'inference_detector', 'show_result_pyplot',
+ 'multi_gpu_test', 'single_gpu_test'
+]
diff --git a/mmdet/apis/inference.py b/mmdet/apis/inference.py
new file mode 100644
index 0000000000000000000000000000000000000000..464d1e2dec8bd30304ec8018922681fe63b77970
--- /dev/null
+++ b/mmdet/apis/inference.py
@@ -0,0 +1,217 @@
+import warnings
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.ops import RoIPool
+from mmcv.parallel import collate, scatter
+from mmcv.runner import load_checkpoint
+
+from mmdet.core import get_classes
+from mmdet.datasets import replace_ImageToTensor
+from mmdet.datasets.pipelines import Compose
+from mmdet.models import build_detector
+
+
+def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
+ """Initialize a detector from config file.
+
+ Args:
+ config (str or :obj:`mmcv.Config`): Config file path or the config
+ object.
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
+ will not load any weights.
+ cfg_options (dict): Options to override some settings in the used
+ config.
+
+ Returns:
+ nn.Module: The constructed detector.
+ """
+ if isinstance(config, str):
+ config = mmcv.Config.fromfile(config)
+ elif not isinstance(config, mmcv.Config):
+ raise TypeError('config must be a filename or Config object, '
+ f'but got {type(config)}')
+ if cfg_options is not None:
+ config.merge_from_dict(cfg_options)
+ config.model.pretrained = None
+ config.model.train_cfg = None
+ model = build_detector(config.model, test_cfg=config.get('test_cfg'))
+ if checkpoint is not None:
+ map_loc = 'cpu' if device == 'cpu' else None
+ checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
+ if 'CLASSES' in checkpoint.get('meta', {}):
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ else:
+ warnings.simplefilter('once')
+ warnings.warn('Class names are not saved in the checkpoint\'s '
+ 'meta data, use COCO classes by default.')
+ model.CLASSES = get_classes('coco')
+ model.cfg = config # save the config in the model for convenience
+ model.to(device)
+ model.eval()
+ return model
+
+
+class LoadImage(object):
+ """Deprecated.
+
+ A simple pipeline to load image.
+ """
+
+ def __call__(self, results):
+ """Call function to load images into results.
+
+ Args:
+ results (dict): A result dict contains the file name
+ of the image to be read.
+ Returns:
+ dict: ``results`` will be returned containing loaded image.
+ """
+ warnings.simplefilter('once')
+ warnings.warn('`LoadImage` is deprecated and will be removed in '
+ 'future releases. You may use `LoadImageFromWebcam` '
+ 'from `mmdet.datasets.pipelines.` instead.')
+ if isinstance(results['img'], str):
+ results['filename'] = results['img']
+ results['ori_filename'] = results['img']
+ else:
+ results['filename'] = None
+ results['ori_filename'] = None
+ img = mmcv.imread(results['img'])
+ results['img'] = img
+ results['img_fields'] = ['img']
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ return results
+
+
+def inference_detector(model, imgs):
+ """Inference image(s) with the detector.
+
+ Args:
+ model (nn.Module): The loaded detector.
+ imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
+ Either image files or loaded images.
+
+ Returns:
+        If imgs is a list or tuple, a list of detection results of the same
+        length will be returned, otherwise the detection result is returned
+        directly.
+ """
+
+ if isinstance(imgs, (list, tuple)):
+ is_batch = True
+ else:
+ imgs = [imgs]
+ is_batch = False
+
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+
+ if isinstance(imgs[0], np.ndarray):
+ cfg = cfg.copy()
+ # set loading pipeline type
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
+ test_pipeline = Compose(cfg.data.test.pipeline)
+
+ datas = []
+ for img in imgs:
+ # prepare data
+ if isinstance(img, np.ndarray):
+ # directly add img
+ data = dict(img=img)
+ else:
+ # add information into dict
+ data = dict(img_info=dict(filename=img), img_prefix=None)
+ # build the data pipeline
+ data = test_pipeline(data)
+ datas.append(data)
+
+ data = collate(datas, samples_per_gpu=len(imgs))
+ # just get the actual data from DataContainer
+ data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
+ data['img'] = [img.data[0] for img in data['img']]
+ if next(model.parameters()).is_cuda:
+ # scatter to specified GPU
+ data = scatter(data, [device])[0]
+ else:
+ for m in model.modules():
+ assert not isinstance(
+ m, RoIPool
+ ), 'CPU inference with RoIPool is not supported currently.'
+
+ # forward the model
+ with torch.no_grad():
+ results = model(return_loss=False, rescale=True, **data)
+
+ if not is_batch:
+ return results[0]
+ else:
+ return results
+
+
+async def async_inference_detector(model, img):
+ """Async inference image(s) with the detector.
+
+ Args:
+ model (nn.Module): The loaded detector.
+ img (str | ndarray): Either image files or loaded images.
+
+ Returns:
+ Awaitable detection results.
+ """
+ cfg = model.cfg
+ device = next(model.parameters()).device # model device
+ # prepare data
+ if isinstance(img, np.ndarray):
+ # directly add img
+ data = dict(img=img)
+ cfg = cfg.copy()
+ # set loading pipeline type
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
+ else:
+ # add information into dict
+ data = dict(img_info=dict(filename=img), img_prefix=None)
+ # build the data pipeline
+ test_pipeline = Compose(cfg.data.test.pipeline)
+ data = test_pipeline(data)
+ data = scatter(collate([data], samples_per_gpu=1), [device])[0]
+
+ # We don't restore `torch.is_grad_enabled()` value during concurrent
+ # inference since execution can overlap
+ torch.set_grad_enabled(False)
+ result = await model.aforward_test(rescale=True, **data)
+ return result
+
+
+def show_result_pyplot(model,
+ img,
+ result,
+ score_thr=0.3,
+ title='result',
+ wait_time=0):
+ """Visualize the detection results on the image.
+
+ Args:
+ model (nn.Module): The loaded detector.
+ img (str or np.ndarray): Image filename or loaded image.
+ result (tuple[list] or list): The detection result, can be either
+ (bbox, segm) or just bbox.
+ score_thr (float): The threshold to visualize the bboxes and masks.
+ title (str): Title of the pyplot figure.
+ wait_time (float): Value of waitKey param.
+ Default: 0.
+ """
+ if hasattr(model, 'module'):
+ model = model.module
+ model.show_result(
+ img,
+ result,
+ score_thr=score_thr,
+ show=True,
+ wait_time=wait_time,
+ win_name=title,
+ bbox_color=(72, 101, 241),
+ text_color=(72, 101, 241))
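+
+# Illustrative end-to-end use of this API (a hedged sketch; the WALT config and
+# checkpoint paths are the ones used elsewhere in this repository, and the demo
+# image path is an assumption):
+#
+#   model = init_detector('configs/walt/walt_vehicle.py',
+#                         'data/models/walt_vehicle.pth', device='cuda:0')
+#   result = inference_detector(model, 'demo/images/img_1.jpg')
+#   show_result_pyplot(model, 'demo/images/img_1.jpg', result, score_thr=0.85)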
diff --git a/mmdet/apis/test.py b/mmdet/apis/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..68b2347e4c2c12b23c7ebc0c0b066735d23cda1b
--- /dev/null
+++ b/mmdet/apis/test.py
@@ -0,0 +1,189 @@
+import os.path as osp
+import pickle
+import shutil
+import tempfile
+import time
+
+import mmcv
+import torch
+import torch.distributed as dist
+from mmcv.image import tensor2imgs
+from mmcv.runner import get_dist_info
+
+from mmdet.core import encode_mask_results
+
+
+def single_gpu_test(model,
+ data_loader,
+ show=False,
+ out_dir=None,
+ show_score_thr=0.3):
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+
+ batch_size = len(result)
+ if show or out_dir:
+ if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
+ img_tensor = data['img'][0]
+ else:
+ img_tensor = data['img'][0].data[0]
+ img_metas = data['img_metas'][0].data[0]
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
+ assert len(imgs) == len(img_metas)
+
+ for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
+ h, w, _ = img_meta['img_shape']
+ img_show = img[:h, :w, :]
+
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
+
+ if out_dir:
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
+ else:
+ out_file = None
+ model.module.show_result(
+ img_show,
+ result[i],
+ show=show,
+ out_file=out_file,
+ score_thr=show_score_thr)
+
+ # encode mask results
+ if isinstance(result[0], tuple):
+ result = [(bbox_results, encode_mask_results(mask_results))
+ for bbox_results, mask_results in result]
+ results.extend(result)
+
+ for _ in range(batch_size):
+ prog_bar.update()
+ return results
+
+
+def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
+ """Test model with multiple gpus.
+
+ This method tests model with multiple gpus and collects the results
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
+    it encodes results to gpu tensors and uses gpu communication for results
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
+ and collects them by the rank 0 worker.
+
+ Args:
+ model (nn.Module): Model to be tested.
+ data_loader (nn.Dataloader): Pytorch data loader.
+ tmpdir (str): Path of directory to save the temporary results from
+ different gpus under cpu mode.
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
+
+ Returns:
+ list: The prediction results.
+ """
+ model.eval()
+ results = []
+ dataset = data_loader.dataset
+ rank, world_size = get_dist_info()
+ if rank == 0:
+ prog_bar = mmcv.ProgressBar(len(dataset))
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
+ for i, data in enumerate(data_loader):
+ with torch.no_grad():
+ result = model(return_loss=False, rescale=True, **data)
+ # encode mask results
+ if isinstance(result[0], tuple):
+ result = [(bbox_results, encode_mask_results(mask_results))
+ for bbox_results, mask_results in result]
+ results.extend(result)
+
+ if rank == 0:
+ batch_size = len(result)
+ for _ in range(batch_size * world_size):
+ prog_bar.update()
+
+ # collect results from all ranks
+ if gpu_collect:
+ results = collect_results_gpu(results, len(dataset))
+ else:
+ results = collect_results_cpu(results, len(dataset), tmpdir)
+ return results
+
+
+def collect_results_cpu(result_part, size, tmpdir=None):
+ rank, world_size = get_dist_info()
+ # create a tmp dir if it is not specified
+ if tmpdir is None:
+ MAX_LEN = 512
+ # 32 is whitespace
+ dir_tensor = torch.full((MAX_LEN, ),
+ 32,
+ dtype=torch.uint8,
+ device='cuda')
+ if rank == 0:
+ mmcv.mkdir_or_exist('.dist_test')
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
+ tmpdir = torch.tensor(
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
+ dir_tensor[:len(tmpdir)] = tmpdir
+ dist.broadcast(dir_tensor, 0)
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
+ else:
+ mmcv.mkdir_or_exist(tmpdir)
+ # dump the part result to the dir
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
+ dist.barrier()
+ # collect all parts
+ if rank != 0:
+ return None
+ else:
+ # load results of all parts from tmp dir
+ part_list = []
+ for i in range(world_size):
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
+ part_list.append(mmcv.load(part_file))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ # remove tmp dir
+ shutil.rmtree(tmpdir)
+ return ordered_results
+
+
+def collect_results_gpu(result_part, size):
+ rank, world_size = get_dist_info()
+ # dump result part to tensor with pickle
+ part_tensor = torch.tensor(
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
+ # gather all result part tensor shape
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
+ dist.all_gather(shape_list, shape_tensor)
+ # padding result part tensor to max length
+ shape_max = torch.tensor(shape_list).max()
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
+ part_send[:shape_tensor[0]] = part_tensor
+ part_recv_list = [
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
+ ]
+ # gather all result part
+ dist.all_gather(part_recv_list, part_send)
+
+ if rank == 0:
+ part_list = []
+ for recv, shape in zip(part_recv_list, shape_list):
+ part_list.append(
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
+ # sort the results
+ ordered_results = []
+ for res in zip(*part_list):
+ ordered_results.extend(list(res))
+ # the dataloader may pad some samples
+ ordered_results = ordered_results[:size]
+ return ordered_results
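+
+# Illustrative single-GPU evaluation loop (a hedged sketch; `cfg` is assumed to
+# be a loaded mmcv.Config with a test dataset, and `model` a detector built
+# from it, as in a standard mmdetection tools/test.py script):
+#
+#   dataset = build_dataset(cfg.data.test)
+#   data_loader = build_dataloader(dataset, samples_per_gpu=1,
+#                                  workers_per_gpu=2, dist=False, shuffle=False)
+#   model = MMDataParallel(model, device_ids=[0])
+#   outputs = single_gpu_test(model, data_loader, show=False,
+#                             out_dir='work_dirs/vis', show_score_thr=0.3)
+#   print(dataset.evaluate(outputs))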
diff --git a/mmdet/apis/train.py b/mmdet/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f2f1f95c0a8e7c9232f7aa490e8104f8e37c4f5
--- /dev/null
+++ b/mmdet/apis/train.py
@@ -0,0 +1,185 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+ Fp16OptimizerHook, OptimizerHook, build_optimizer,
+ build_runner)
+from mmcv.utils import build_from_cfg
+
+from mmdet.core import DistEvalHook, EvalHook
+from mmdet.datasets import (build_dataloader, build_dataset,
+ replace_ImageToTensor)
+from mmdet.utils import get_root_logger
+from mmcv_custom.runner import EpochBasedRunnerAmp
+try:
+    import apex
+except ImportError:
+    print('apex is not installed')
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def train_detector(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ if 'imgs_per_gpu' in cfg.data:
+ logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
+ 'Please use "samples_per_gpu" instead')
+ if 'samples_per_gpu' in cfg.data:
+ logger.warning(
+ f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+ f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+ else:
+ logger.warning(
+ 'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+ cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
+ data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ # cfg.gpus will be ignored if distributed
+ len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed) for ds in dataset
+ ]
+
+ # build optimizer
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ # use apex fp16 optimizer
+ if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
+ if cfg.optimizer_config.get("use_fp16", False):
+ model, optimizer = apex.amp.initialize(
+ model.cuda(), optimizer, opt_level="O1")
+ for m in model.modules():
+ if hasattr(m, "fp16_enabled"):
+ m.fp16_enabled = True
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDataParallel(
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+ if 'runner' not in cfg:
+ cfg.runner = {
+ 'type': 'EpochBasedRunner',
+ 'max_epochs': cfg.total_epochs
+ }
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+ else:
+ if 'total_epochs' in cfg:
+ assert cfg.total_epochs == cfg.runner.max_epochs
+
+ # build runner
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # an ugly workaround to make .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # fp16 setting
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ optimizer_config = Fp16OptimizerHook(
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+ elif distributed and 'type' not in cfg.optimizer_config:
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config, optimizer_config,
+ cfg.checkpoint_config, cfg.log_config,
+ cfg.get('momentum_config', None))
+ if distributed:
+ if isinstance(runner, EpochBasedRunner):
+ runner.register_hook(DistSamplerSeedHook())
+
+ # register eval hooks
+ if validate:
+ # Support batch_size > 1 in validation
+ val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
+ if val_samples_per_gpu > 1:
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+ cfg.data.val.pipeline = replace_ImageToTensor(
+ cfg.data.val.pipeline)
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+ val_dataloader = build_dataloader(
+ val_dataset,
+ samples_per_gpu=val_samples_per_gpu,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+ # user-defined hooks
+ if cfg.get('custom_hooks', None):
+ custom_hooks = cfg.custom_hooks
+ assert isinstance(custom_hooks, list), \
+ f'custom_hooks expect list type, but got {type(custom_hooks)}'
+ for hook_cfg in cfg.custom_hooks:
+ assert isinstance(hook_cfg, dict), \
+ 'Each item in custom_hooks expects dict type, but got ' \
+ f'{type(hook_cfg)}'
+ hook_cfg = hook_cfg.copy()
+ priority = hook_cfg.pop('priority', 'NORMAL')
+ hook = build_from_cfg(hook_cfg, HOOKS)
+ runner.register_hook(hook, priority=priority)
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
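+
+# Illustrative config fragment enabling the apex fp16 branch above (a hedged
+# sketch; "type" and "use_fp16" are the two keys train_detector actually
+# checks, the remaining options are assumptions):
+#
+#   optimizer_config = dict(type='DistOptimizerHook', use_fp16=True,
+#                           grad_clip=None)
+#   runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)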
diff --git a/mmdet/core/__init__.py b/mmdet/core/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e812391e23894ef296755381386d4849f774418a
--- /dev/null
+++ b/mmdet/core/__init__.py
@@ -0,0 +1,7 @@
+from .anchor import * # noqa: F401, F403
+from .bbox import * # noqa: F401, F403
+from .evaluation import * # noqa: F401, F403
+from .export import * # noqa: F401, F403
+from .mask import * # noqa: F401, F403
+from .post_processing import * # noqa: F401, F403
+from .utils import * # noqa: F401, F403
diff --git a/mmdet/core/anchor/__init__.py b/mmdet/core/anchor/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5838ff3eefb03bc83928fa13848cea9ff8647827
--- /dev/null
+++ b/mmdet/core/anchor/__init__.py
@@ -0,0 +1,11 @@
+from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
+ YOLOAnchorGenerator)
+from .builder import ANCHOR_GENERATORS, build_anchor_generator
+from .point_generator import PointGenerator
+from .utils import anchor_inside_flags, calc_region, images_to_levels
+
+__all__ = [
+ 'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
+ 'PointGenerator', 'images_to_levels', 'calc_region',
+ 'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
+]
diff --git a/mmdet/core/anchor/anchor_generator.py b/mmdet/core/anchor/anchor_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..388d2608b8138da13d1208b99595fbd1db59d178
--- /dev/null
+++ b/mmdet/core/anchor/anchor_generator.py
@@ -0,0 +1,727 @@
+import mmcv
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+from .builder import ANCHOR_GENERATORS
+
+
+@ANCHOR_GENERATORS.register_module()
+class AnchorGenerator(object):
+ """Standard anchor generator for 2D anchor-based detectors.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels in order (w, h).
+ ratios (list[float]): The list of ratios between the height and width
+ of anchors in a single level.
+ scales (list[int] | None): Anchor scales for anchors in a single level.
+            It cannot be set at the same time as `octave_base_scale` and
+            `scales_per_octave`.
+ base_sizes (list[int] | None): The basic sizes
+ of anchors in multiple levels.
+ If None is given, strides will be used as base_sizes.
+ (If strides are non square, the shortest stride is taken.)
+ scale_major (bool): Whether to multiply scales first when generating
+ base anchors. If true, the anchors in the same row will have the
+ same scales. By default it is True in V2.0
+ octave_base_scale (int): The base scale of octave.
+ scales_per_octave (int): Number of scales for each octave.
+ `octave_base_scale` and `scales_per_octave` are usually used in
+ retinanet and the `scales` should be None when they are set.
+ centers (list[tuple[float, float]] | None): The centers of the anchor
+ relative to the feature grid center in multiple feature levels.
+ By default it is set to be None and not used. If a list of tuple of
+ float is given, they will be used to shift the centers of anchors.
+ center_offset (float): The offset of center in proportion to anchors'
+ width and height. By default it is 0 in V2.0.
+
+ Examples:
+ >>> from mmdet.core import AnchorGenerator
+ >>> self = AnchorGenerator([16], [1.], [1.], [9])
+ >>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
+ >>> print(all_anchors)
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
+ [11.5000, -4.5000, 20.5000, 4.5000],
+ [-4.5000, 11.5000, 4.5000, 20.5000],
+ [11.5000, 11.5000, 20.5000, 20.5000]])]
+ >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
+ >>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
+ >>> print(all_anchors)
+ [tensor([[-4.5000, -4.5000, 4.5000, 4.5000],
+ [11.5000, -4.5000, 20.5000, 4.5000],
+ [-4.5000, 11.5000, 4.5000, 20.5000],
+ [11.5000, 11.5000, 20.5000, 20.5000]]), \
+ tensor([[-9., -9., 9., 9.]])]
+ """
+
+ def __init__(self,
+ strides,
+ ratios,
+ scales=None,
+ base_sizes=None,
+ scale_major=True,
+ octave_base_scale=None,
+ scales_per_octave=None,
+ centers=None,
+ center_offset=0.):
+ # check center and center_offset
+ if center_offset != 0:
+ assert centers is None, 'center cannot be set when center_offset' \
+ f'!=0, {centers} is given.'
+ if not (0 <= center_offset <= 1):
+ raise ValueError('center_offset should be in range [0, 1], '
+ f'{center_offset} is given.')
+ if centers is not None:
+ assert len(centers) == len(strides), \
+ 'The number of strides should be the same as centers, got ' \
+ f'{strides} and {centers}'
+
+ # calculate base sizes of anchors
+ self.strides = [_pair(stride) for stride in strides]
+ self.base_sizes = [min(stride) for stride in self.strides
+ ] if base_sizes is None else base_sizes
+ assert len(self.base_sizes) == len(self.strides), \
+ 'The number of strides should be the same as base sizes, got ' \
+ f'{self.strides} and {self.base_sizes}'
+
+ # calculate scales of anchors
+ assert ((octave_base_scale is not None
+ and scales_per_octave is not None) ^ (scales is not None)), \
+ 'scales and octave_base_scale with scales_per_octave cannot' \
+ ' be set at the same time'
+ if scales is not None:
+ self.scales = torch.Tensor(scales)
+ elif octave_base_scale is not None and scales_per_octave is not None:
+ octave_scales = np.array(
+ [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+ scales = octave_scales * octave_base_scale
+ self.scales = torch.Tensor(scales)
+ else:
+ raise ValueError('Either scales or octave_base_scale with '
+ 'scales_per_octave should be set')
+
+ self.octave_base_scale = octave_base_scale
+ self.scales_per_octave = scales_per_octave
+ self.ratios = torch.Tensor(ratios)
+ self.scale_major = scale_major
+ self.centers = centers
+ self.center_offset = center_offset
+ self.base_anchors = self.gen_base_anchors()
+
+ @property
+ def num_base_anchors(self):
+ """list[int]: total number of base anchors in a feature grid"""
+ return [base_anchors.size(0) for base_anchors in self.base_anchors]
+
+ @property
+ def num_levels(self):
+ """int: number of feature levels that the generator will be applied"""
+ return len(self.strides)
+
+ def gen_base_anchors(self):
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_size in enumerate(self.base_sizes):
+ center = None
+ if self.centers is not None:
+ center = self.centers[i]
+ multi_level_base_anchors.append(
+ self.gen_single_level_base_anchors(
+ base_size,
+ scales=self.scales,
+ ratios=self.ratios,
+ center=center))
+ return multi_level_base_anchors
+
+ def gen_single_level_base_anchors(self,
+ base_size,
+ scales,
+ ratios,
+ center=None):
+ """Generate base anchors of a single level.
+
+ Args:
+ base_size (int | float): Basic size of an anchor.
+ scales (torch.Tensor): Scales of the anchor.
+            ratios (torch.Tensor): The ratio between the height
+ and width of anchors in a single level.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+ torch.Tensor: Anchors in a single-level feature maps.
+ """
+ w = base_size
+ h = base_size
+ if center is None:
+ x_center = self.center_offset * w
+ y_center = self.center_offset * h
+ else:
+ x_center, y_center = center
+
+ h_ratios = torch.sqrt(ratios)
+ w_ratios = 1 / h_ratios
+ if self.scale_major:
+ ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
+ hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
+ else:
+ ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
+ hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchors = [
+ x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
+ y_center + 0.5 * hs
+ ]
+ base_anchors = torch.stack(base_anchors, dim=-1)
+
+ return base_anchors
+
+ def _meshgrid(self, x, y, row_major=True):
+ """Generate mesh grid of x and y.
+
+ Args:
+ x (torch.Tensor): Grids of x dimension.
+ y (torch.Tensor): Grids of y dimension.
+ row_major (bool, optional): Whether to return y grids first.
+ Defaults to True.
+
+ Returns:
+ tuple[torch.Tensor]: The mesh grids of x and y.
+ """
+ # use shape instead of len to keep tracing while exporting to onnx
+ xx = x.repeat(y.shape[0])
+ yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
+ if row_major:
+ return xx, yy
+ else:
+ return yy, xx
+
+ def grid_anchors(self, featmap_sizes, device='cuda'):
+ """Generate grid anchors in multiple feature levels.
+
+ Args:
+ featmap_sizes (list[tuple]): List of feature map sizes in
+ multiple feature levels.
+ device (str): Device where the anchors will be put on.
+
+ Return:
+ list[torch.Tensor]: Anchors in multiple feature levels. \
+ The sizes of each tensor should be [N, 4], where \
+ N = width * height * num_base_anchors, width and height \
+ are the sizes of the corresponding feature level, \
+ num_base_anchors is the number of anchors for that level.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_anchors = []
+ for i in range(self.num_levels):
+ anchors = self.single_level_grid_anchors(
+ self.base_anchors[i].to(device),
+ featmap_sizes[i],
+ self.strides[i],
+ device=device)
+ multi_level_anchors.append(anchors)
+ return multi_level_anchors
+
+ def single_level_grid_anchors(self,
+ base_anchors,
+ featmap_size,
+ stride=(16, 16),
+ device='cuda'):
+ """Generate grid anchors of a single level.
+
+ Note:
+ This function is usually called by method ``self.grid_anchors``.
+
+ Args:
+ base_anchors (torch.Tensor): The base anchors of a feature grid.
+ featmap_size (tuple[int]): Size of the feature maps.
+ stride (tuple[int], optional): Stride of the feature map in order
+ (w, h). Defaults to (16, 16).
+ device (str, optional): Device the tensor will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: Anchors in the overall feature maps.
+ """
+ # keep as Tensor, so that we can convert to ONNX correctly
+ feat_h, feat_w = featmap_size
+ shift_x = torch.arange(0, feat_w, device=device) * stride[0]
+ shift_y = torch.arange(0, feat_h, device=device) * stride[1]
+
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
+ shifts = shifts.type_as(base_anchors)
+ # first feat_w elements correspond to the first row of shifts
+ # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
+ # shifted anchors (K, A, 4), reshape to (K*A, 4)
+
+ all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
+ all_anchors = all_anchors.view(-1, 4)
+ # first A rows correspond to A anchors of (0, 0) in feature map,
+ # then (0, 1), (0, 2), ...
+ return all_anchors
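+
+ # Shape sketch (illustrative numbers): with A = 3 base anchors and a
+ # 2x2 feature map (K = 4 shifts), the broadcast above adds (1, 3, 4)
+ # to (4, 1, 4) and reshapes to (12, 4); rows 0-2 are the anchors on
+ # the first grid cell, rows 3-5 on the next cell along the first row,
+ # and so on in row-major order.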
+
+ def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
+ """Generate valid flags of anchors in multiple feature levels.
+
+ Args:
+ featmap_sizes (list(tuple)): List of feature map sizes in
+ multiple feature levels.
+ pad_shape (tuple): The padded shape of the image.
+ device (str): Device where the anchors will be put on.
+
+ Returns:
+ list(torch.Tensor): Valid flags of anchors in multiple levels.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_flags = []
+ for i in range(self.num_levels):
+ anchor_stride = self.strides[i]
+ feat_h, feat_w = featmap_sizes[i]
+ h, w = pad_shape[:2]
+ valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
+ valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
+ flags = self.single_level_valid_flags((feat_h, feat_w),
+ (valid_feat_h, valid_feat_w),
+ self.num_base_anchors[i],
+ device=device)
+ multi_level_flags.append(flags)
+ return multi_level_flags
+
+ def single_level_valid_flags(self,
+ featmap_size,
+ valid_size,
+ num_base_anchors,
+ device='cuda'):
+ """Generate the valid flags of anchor in a single feature map.
+
+ Args:
+ featmap_size (tuple[int]): The size of feature maps.
+ valid_size (tuple[int]): The valid size of the feature maps.
+ num_base_anchors (int): The number of base anchors.
+ device (str, optional): Device where the flags will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: The valid flags of each anchor in a single level \
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
+ valid = valid[:, None].expand(valid.size(0),
+ num_base_anchors).contiguous().view(-1)
+ return valid
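+
+ # Worked example (illustrative sizes): for featmap_size=(3, 4) and
+ # valid_size=(2, 3), valid_y=[1, 1, 0] and valid_x=[1, 1, 1, 0], so
+ # the row-major grid flags are [1, 1, 1, 0, 1, 1, 1, 0, 0, 0, 0, 0];
+ # each flag is then repeated num_base_anchors times.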
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ indent_str = ' '
+ repr_str = self.__class__.__name__ + '(\n'
+ repr_str += f'{indent_str}strides={self.strides},\n'
+ repr_str += f'{indent_str}ratios={self.ratios},\n'
+ repr_str += f'{indent_str}scales={self.scales},\n'
+ repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
+ repr_str += f'{indent_str}scale_major={self.scale_major},\n'
+ repr_str += f'{indent_str}octave_base_scale='
+ repr_str += f'{self.octave_base_scale},\n'
+ repr_str += f'{indent_str}scales_per_octave='
+ repr_str += f'{self.scales_per_octave},\n'
+ repr_str += f'{indent_str}num_levels={self.num_levels},\n'
+ repr_str += f'{indent_str}centers={self.centers},\n'
+ repr_str += f'{indent_str}center_offset={self.center_offset})'
+ return repr_str
+
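+# Usage sketch (illustrative arguments; the expected output below is computed
+# from the anchor math above, not taken from any config in this repo):
+# >>> from mmdet.core import AnchorGenerator
+# >>> self = AnchorGenerator([16], [1.], [1.], [9])
+# >>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
+# >>> print(all_anchors)
+# [tensor([[-4.5000, -4.5000,  4.5000,  4.5000],
+#          [11.5000, -4.5000, 20.5000,  4.5000],
+#          [-4.5000, 11.5000,  4.5000, 20.5000],
+#          [11.5000, 11.5000, 20.5000, 20.5000]])]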
+
+@ANCHOR_GENERATORS.register_module()
+class SSDAnchorGenerator(AnchorGenerator):
+ """Anchor generator for SSD.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels.
+ ratios (list[float]): The list of ratios between the height and width
+ of anchors in a single level.
+ basesize_ratio_range (tuple(float)): Ratio range of anchors.
+ input_size (int): Size of feature map, 300 for SSD300,
+ 512 for SSD512.
+ scale_major (bool): Whether to multiply scales first when generating
+ base anchors. If true, the anchors in the same row will have the
+ same scales. It is always set to be False in SSD.
+ """
+
+ def __init__(self,
+ strides,
+ ratios,
+ basesize_ratio_range,
+ input_size=300,
+ scale_major=True):
+ assert len(strides) == len(ratios)
+ assert mmcv.is_tuple_of(basesize_ratio_range, float)
+
+ self.strides = [_pair(stride) for stride in strides]
+ self.input_size = input_size
+ self.centers = [(stride[0] / 2., stride[1] / 2.)
+ for stride in self.strides]
+ self.basesize_ratio_range = basesize_ratio_range
+
+ # calculate anchor ratios and sizes
+ min_ratio, max_ratio = basesize_ratio_range
+ min_ratio = int(min_ratio * 100)
+ max_ratio = int(max_ratio * 100)
+ step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
+ min_sizes = []
+ max_sizes = []
+ for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
+ min_sizes.append(int(self.input_size * ratio / 100))
+ max_sizes.append(int(self.input_size * (ratio + step) / 100))
+ if self.input_size == 300:
+ if basesize_ratio_range[0] == 0.15: # SSD300 COCO
+ min_sizes.insert(0, int(self.input_size * 7 / 100))
+ max_sizes.insert(0, int(self.input_size * 15 / 100))
+ elif basesize_ratio_range[0] == 0.2: # SSD300 VOC
+ min_sizes.insert(0, int(self.input_size * 10 / 100))
+ max_sizes.insert(0, int(self.input_size * 20 / 100))
+ else:
+ raise ValueError(
+ 'basesize_ratio_range[0] should be either 0.15 '
+ 'or 0.2 when input_size is 300, got '
+ f'{basesize_ratio_range[0]}.')
+ elif self.input_size == 512:
+ if basesize_ratio_range[0] == 0.1: # SSD512 COCO
+ min_sizes.insert(0, int(self.input_size * 4 / 100))
+ max_sizes.insert(0, int(self.input_size * 10 / 100))
+ elif basesize_ratio_range[0] == 0.15: # SSD512 VOC
+ min_sizes.insert(0, int(self.input_size * 7 / 100))
+ max_sizes.insert(0, int(self.input_size * 15 / 100))
+ else:
+ raise ValueError('basesize_ratio_range[0] should be either 0.1 '
+ 'or 0.15 when input_size is 512, got'
+ f' {basesize_ratio_range[0]}.')
+ else:
+ raise ValueError('Only support 300 or 512 in SSDAnchorGenerator'
+ f', got {self.input_size}.')
+
+ anchor_ratios = []
+ anchor_scales = []
+ for k in range(len(self.strides)):
+ scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
+ anchor_ratio = [1.]
+ for r in ratios[k]:
+ anchor_ratio += [1 / r, r] # 4 or 6 ratio
+ anchor_ratios.append(torch.Tensor(anchor_ratio))
+ anchor_scales.append(torch.Tensor(scales))
+
+ self.base_sizes = min_sizes
+ self.scales = anchor_scales
+ self.ratios = anchor_ratios
+ self.scale_major = scale_major
+ self.center_offset = 0
+ self.base_anchors = self.gen_base_anchors()
+
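+ # Worked example (illustrative SSD300-style settings, not asserted to
+ # match any config in this repo):
+ # >>> ssd_ag = SSDAnchorGenerator(
+ # >>>     strides=[8, 16, 32, 64, 100, 300],
+ # >>>     ratios=[[2], [2, 3], [2, 3], [2, 3], [2], [2]],
+ # >>>     basesize_ratio_range=(0.15, 0.9))
+ # min_ratio=15 and max_ratio=90 give step=int(75/4)=18, so the loop
+ # yields ratios 15, 33, 51, 69, 87; the input_size==300 / 0.15 branch
+ # then prepends the extra small anchor, leaving
+ # ssd_ag.base_sizes == [21, 45, 99, 153, 207, 261] and
+ # max_sizes == [45, 99, 153, 207, 261, 315].
+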
+ def gen_base_anchors(self):
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_size in enumerate(self.base_sizes):
+ base_anchors = self.gen_single_level_base_anchors(
+ base_size,
+ scales=self.scales[i],
+ ratios=self.ratios[i],
+ center=self.centers[i])
+ indices = list(range(len(self.ratios[i])))
+ indices.insert(1, len(indices))
+ base_anchors = torch.index_select(base_anchors, 0,
+ torch.LongTensor(indices))
+ multi_level_base_anchors.append(base_anchors)
+ return multi_level_base_anchors
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ indent_str = ' '
+ repr_str = self.__class__.__name__ + '(\n'
+ repr_str += f'{indent_str}strides={self.strides},\n'
+ repr_str += f'{indent_str}scales={self.scales},\n'
+ repr_str += f'{indent_str}scale_major={self.scale_major},\n'
+ repr_str += f'{indent_str}input_size={self.input_size},\n'
+ repr_str += f'{indent_str}ratios={self.ratios},\n'
+ repr_str += f'{indent_str}num_levels={self.num_levels},\n'
+ repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
+ repr_str += f'{indent_str}basesize_ratio_range='
+ repr_str += f'{self.basesize_ratio_range})'
+ return repr_str
+
+
+@ANCHOR_GENERATORS.register_module()
+class LegacyAnchorGenerator(AnchorGenerator):
+ """Legacy anchor generator used in MMDetection V1.x.
+
+ Note:
+ Difference to the V2.0 anchor generator:
+
+ 1. The center offset of V1.x anchors are set to be 0.5 rather than 0.
+ 2. The width/height are reduced by 1 when calculating the anchors' \
+ centers and corners to meet the V1.x coordinate system.
+ 3. The anchors' corners are quantized.
+
+ Args:
+ strides (list[int] | list[tuple[int]]): Strides of anchors
+ in multiple feature levels.
+ ratios (list[float]): The list of ratios between the height and width
+ of anchors in a single level.
+ scales (list[int] | None): Anchor scales for anchors in a single level.
+ It cannot be set at the same time as `octave_base_scale` and
+ `scales_per_octave`.
+ base_sizes (list[int]): The basic sizes of anchors in multiple levels.
+ If None is given, strides will be used to generate base_sizes.
+ scale_major (bool): Whether to multiply scales first when generating
+ base anchors. If true, the anchors in the same row will have the
+ same scales. By default it is True in V2.0
+ octave_base_scale (int): The base scale of octave.
+ scales_per_octave (int): Number of scales for each octave.
+ `octave_base_scale` and `scales_per_octave` are usually used in
+ retinanet and the `scales` should be None when they are set.
+ centers (list[tuple[float, float]] | None): The centers of the anchor
+ relative to the feature grid center in multiple feature levels.
+ By default it is set to be None and not used. If a list of float
+ is given, this list will be used to shift the centers of anchors.
+ center_offset (float): The offset of center in proportion to anchors'
+ width and height. By default it is 0 in V2.0, but it should be 0.5
+ in v1.x models.
+
+ Examples:
+ >>> from mmdet.core import LegacyAnchorGenerator
+ >>> self = LegacyAnchorGenerator(
+ >>> [16], [1.], [1.], [9], center_offset=0.5)
+ >>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
+ >>> print(all_anchors)
+ [tensor([[ 0., 0., 8., 8.],
+ [16., 0., 24., 8.],
+ [ 0., 16., 8., 24.],
+ [16., 16., 24., 24.]])]
+ """
+
+ def gen_single_level_base_anchors(self,
+ base_size,
+ scales,
+ ratios,
+ center=None):
+ """Generate base anchors of a single level.
+
+ Note:
+ The width/height of anchors are reduced by 1 when calculating \
+ the centers and corners to meet the V1.x coordinate system.
+
+ Args:
+ base_size (int | float): Basic size of an anchor.
+ scales (torch.Tensor): Scales of the anchor.
+ ratios (torch.Tensor): The ratio between the height
+ and width of anchors in a single level.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+ torch.Tensor: Anchors in a single-level feature map.
+ """
+ w = base_size
+ h = base_size
+ if center is None:
+ x_center = self.center_offset * (w - 1)
+ y_center = self.center_offset * (h - 1)
+ else:
+ x_center, y_center = center
+
+ h_ratios = torch.sqrt(ratios)
+ w_ratios = 1 / h_ratios
+ if self.scale_major:
+ ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
+ hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
+ else:
+ ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
+ hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchors = [
+ x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
+ x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
+ ]
+ base_anchors = torch.stack(base_anchors, dim=-1).round()
+
+ return base_anchors
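+
+ # Worked check (consistent with the class docstring example above):
+ # with base_size=9, scale 1 and center_offset=0.5, the legacy center
+ # is 0.5 * (9 - 1) = 4, so the anchor spans [4 - 4, 4 - 4, 4 + 4,
+ # 4 + 4] = [0., 0., 8., 8.], one pixel tighter than the V2.0 anchor
+ # [-4.5, -4.5, 4.5, 4.5].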
+
+
+@ANCHOR_GENERATORS.register_module()
+class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
+ """Legacy anchor generator used in MMDetection V1.x.
+
+ The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
+ can be found in `LegacyAnchorGenerator`.
+ """
+
+ def __init__(self,
+ strides,
+ ratios,
+ basesize_ratio_range,
+ input_size=300,
+ scale_major=True):
+ super(LegacySSDAnchorGenerator,
+ self).__init__(strides, ratios, basesize_ratio_range, input_size,
+ scale_major)
+ self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
+ for stride in strides]
+ self.base_anchors = self.gen_base_anchors()
+
+
+@ANCHOR_GENERATORS.register_module()
+class YOLOAnchorGenerator(AnchorGenerator):
+ """Anchor generator for YOLO.
+
+ Args:
+ strides (list[int] | list[tuple[int, int]]): Strides of anchors
+ in multiple feature levels.
+ base_sizes (list[list[tuple[int, int]]]): The basic sizes
+ of anchors in multiple levels.
+ """
+
+ def __init__(self, strides, base_sizes):
+ self.strides = [_pair(stride) for stride in strides]
+ self.centers = [(stride[0] / 2., stride[1] / 2.)
+ for stride in self.strides]
+ self.base_sizes = []
+ num_anchor_per_level = len(base_sizes[0])
+ for base_sizes_per_level in base_sizes:
+ assert num_anchor_per_level == len(base_sizes_per_level)
+ self.base_sizes.append(
+ [_pair(base_size) for base_size in base_sizes_per_level])
+ self.base_anchors = self.gen_base_anchors()
+
+ @property
+ def num_levels(self):
+ """int: number of feature levels that the generator will be applied"""
+ return len(self.base_sizes)
+
+ def gen_base_anchors(self):
+ """Generate base anchors.
+
+ Returns:
+ list(torch.Tensor): Base anchors of a feature grid in multiple \
+ feature levels.
+ """
+ multi_level_base_anchors = []
+ for i, base_sizes_per_level in enumerate(self.base_sizes):
+ center = None
+ if self.centers is not None:
+ center = self.centers[i]
+ multi_level_base_anchors.append(
+ self.gen_single_level_base_anchors(base_sizes_per_level,
+ center))
+ return multi_level_base_anchors
+
+ def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
+ """Generate base anchors of a single level.
+
+ Args:
+ base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
+ anchors.
+ center (tuple[float], optional): The center of the base anchor
+ related to a single feature grid. Defaults to None.
+
+ Returns:
+ torch.Tensor: Anchors in a single-level feature map.
+ """
+ x_center, y_center = center
+ base_anchors = []
+ for base_size in base_sizes_per_level:
+ w, h = base_size
+
+ # use float anchor and the anchor's center is aligned with the
+ # pixel center
+ base_anchor = torch.Tensor([
+ x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
+ y_center + 0.5 * h
+ ])
+ base_anchors.append(base_anchor)
+ base_anchors = torch.stack(base_anchors, dim=0)
+
+ return base_anchors
+
+ def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
+ """Generate responsible anchor flags of grid cells in multiple scales.
+
+ Args:
+ featmap_sizes (list(tuple)): List of feature map sizes in multiple
+ feature levels.
+ gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
+ device (str): Device where the anchors will be put on.
+
+ Returns:
+ list(torch.Tensor): responsible flags of anchors in multiple levels.
+ """
+ assert self.num_levels == len(featmap_sizes)
+ multi_level_responsible_flags = []
+ for i in range(self.num_levels):
+ anchor_stride = self.strides[i]
+ flags = self.single_level_responsible_flags(
+ featmap_sizes[i],
+ gt_bboxes,
+ anchor_stride,
+ self.num_base_anchors[i],
+ device=device)
+ multi_level_responsible_flags.append(flags)
+ return multi_level_responsible_flags
+
+ def single_level_responsible_flags(self,
+ featmap_size,
+ gt_bboxes,
+ stride,
+ num_base_anchors,
+ device='cuda'):
+ """Generate the responsible flags of anchor in a single feature map.
+
+ Args:
+ featmap_size (tuple[int]): The size of feature maps.
+ gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
+ stride (tuple(int)): stride of current level
+ num_base_anchors (int): The number of base anchors.
+ device (str, optional): Device where the flags will be put on.
+ Defaults to 'cuda'.
+
+ Returns:
+ torch.Tensor: The valid flags of each anchor in a single level \
+ feature map.
+ """
+ feat_h, feat_w = featmap_size
+ gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
+ gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
+ gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
+ gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
+
+ # row major indexing
+ gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x
+
+ responsible_grid = torch.zeros(
+ feat_h * feat_w, dtype=torch.uint8, device=device)
+ responsible_grid[gt_bboxes_grid_idx] = 1
+
+ responsible_grid = responsible_grid[:, None].expand(
+ responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
+ return responsible_grid
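+
+ # Worked example (illustrative numbers): for featmap_size=(2, 2),
+ # stride=(16, 16) and one gt box [0, 0, 10, 10], the gt center (5, 5)
+ # falls in grid cell (0, 0) -> flat index 0, so with
+ # num_base_anchors=3 the returned flags are
+ # [1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0].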
diff --git a/mmdet/core/anchor/builder.py b/mmdet/core/anchor/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d79b448ebca9f2b21d455046623172c48c5c3ef0
--- /dev/null
+++ b/mmdet/core/anchor/builder.py
@@ -0,0 +1,7 @@
+from mmcv.utils import Registry, build_from_cfg
+
+ANCHOR_GENERATORS = Registry('Anchor generator')
+
+
+def build_anchor_generator(cfg, default_args=None):
+ return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
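+
+
+# Usage sketch (hedged: the config values below are illustrative, not taken
+# from a model config in this repo):
+# >>> cfg = dict(type='AnchorGenerator', strides=[16], ratios=[1.],
+# >>>            scales=[1.], base_sizes=[9])
+# >>> anchor_generator = build_anchor_generator(cfg)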
diff --git a/mmdet/core/anchor/point_generator.py b/mmdet/core/anchor/point_generator.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6fbd988c317992c092c68c827dc4c53223b4a4a
--- /dev/null
+++ b/mmdet/core/anchor/point_generator.py
@@ -0,0 +1,37 @@
+import torch
+
+from .builder import ANCHOR_GENERATORS
+
+
+@ANCHOR_GENERATORS.register_module()
+class PointGenerator(object):
+
+ def _meshgrid(self, x, y, row_major=True):
+ xx = x.repeat(len(y))
+ yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
+ if row_major:
+ return xx, yy
+ else:
+ return yy, xx
+
+ def grid_points(self, featmap_size, stride=16, device='cuda'):
+ feat_h, feat_w = featmap_size
+ shift_x = torch.arange(0., feat_w, device=device) * stride
+ shift_y = torch.arange(0., feat_h, device=device) * stride
+ shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+ stride = shift_x.new_full((shift_xx.shape[0], ), stride)
+ shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
+ all_points = shifts.to(device)
+ return all_points
+
+ def valid_flags(self, featmap_size, valid_size, device='cuda'):
+ feat_h, feat_w = featmap_size
+ valid_h, valid_w = valid_size
+ assert valid_h <= feat_h and valid_w <= feat_w
+ valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+ valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+ valid_x[:valid_w] = 1
+ valid_y[:valid_h] = 1
+ valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+ valid = valid_xx & valid_yy
+ return valid
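+
+# Worked example (illustrative sizes): grid_points((2, 2), stride=16,
+# device='cpu') returns a (4, 3) tensor of (x, y, stride) rows:
+# [[0, 0, 16], [16, 0, 16], [0, 16, 16], [16, 16, 16]].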
diff --git a/mmdet/core/anchor/utils.py b/mmdet/core/anchor/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab9b53f37f7be1f52fe63c5e53df64ac1303b9e0
--- /dev/null
+++ b/mmdet/core/anchor/utils.py
@@ -0,0 +1,71 @@
+import torch
+
+
+def images_to_levels(target, num_levels):
+ """Convert targets by image to targets by feature level.
+
+ [target_img0, target_img1] -> [target_level0, target_level1, ...]
+ """
+ target = torch.stack(target, 0)
+ level_targets = []
+ start = 0
+ for n in num_levels:
+ end = start + n
+ # level_targets.append(target[:, start:end].squeeze(0))
+ level_targets.append(target[:, start:end])
+ start = end
+ return level_targets
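+
+# Shape sketch (illustrative sizes): with 2 images whose per-image targets
+# cover 5 + 3 anchors over two levels (num_levels=[5, 3]), the stacked
+# tensor is (2, 8, ...) and the result is [target[:, 0:5], target[:, 5:8]],
+# i.e. one (2, 5, ...) tensor and one (2, 3, ...) tensor.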
+
+
+def anchor_inside_flags(flat_anchors,
+ valid_flags,
+ img_shape,
+ allowed_border=0):
+ """Check whether the anchors are inside the border.
+
+ Args:
+ flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
+ valid_flags (torch.Tensor): An existing valid flags of anchors.
+ img_shape (tuple(int)): Shape of current image.
+ allowed_border (int, optional): The border to allow the valid anchor.
+ Defaults to 0.
+
+ Returns:
+ torch.Tensor: Flags indicating whether the anchors are inside a \
+ valid range.
+ """
+ img_h, img_w = img_shape[:2]
+ if allowed_border >= 0:
+ inside_flags = valid_flags & \
+ (flat_anchors[:, 0] >= -allowed_border) & \
+ (flat_anchors[:, 1] >= -allowed_border) & \
+ (flat_anchors[:, 2] < img_w + allowed_border) & \
+ (flat_anchors[:, 3] < img_h + allowed_border)
+ else:
+ inside_flags = valid_flags
+ return inside_flags
+
+
+def calc_region(bbox, ratio, featmap_size=None):
+ """Calculate a proportional bbox region.
+
+ The bbox center is fixed; `ratio` is taken from each side, so the new
+ w' and h' are (1 - 2 * ratio) * w and (1 - 2 * ratio) * h.
+
+ Args:
+ bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
+ ratio (float): Ratio of the output region.
+ featmap_size (tuple): Feature map size used for clipping the boundary.
+
+ Returns:
+ tuple: x1, y1, x2, y2
+ """
+ x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
+ y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
+ x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
+ y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
+ if featmap_size is not None:
+ x1 = x1.clamp(min=0, max=featmap_size[1])
+ y1 = y1.clamp(min=0, max=featmap_size[0])
+ x2 = x2.clamp(min=0, max=featmap_size[1])
+ y2 = y2.clamp(min=0, max=featmap_size[0])
+ return (x1, y1, x2, y2)
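+
+
+# Worked example (illustrative values): for bbox = (0, 0, 100, 100) and
+# ratio = 0.25, x1 = round(0.75 * 0 + 0.25 * 100) = 25 and
+# x2 = round(0.25 * 0 + 0.75 * 100) = 75, so the region is
+# (25, 25, 75, 75): a centered box with half the original width and height.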
diff --git a/mmdet/core/bbox/__init__.py b/mmdet/core/bbox/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3537297f57e4c3670afdb97b5fcb1b2d775e5f3
--- /dev/null
+++ b/mmdet/core/bbox/__init__.py
@@ -0,0 +1,27 @@
+from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
+ MaxIoUAssigner, RegionAssigner)
+from .builder import build_assigner, build_bbox_coder, build_sampler
+from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
+ TBLRBBoxCoder)
+from .iou_calculators import BboxOverlaps2D, bbox_overlaps
+from .samplers import (BaseSampler, CombinedSampler,
+ InstanceBalancedPosSampler, IoUBalancedNegSampler,
+ OHEMSampler, PseudoSampler, RandomSampler,
+ SamplingResult, ScoreHLRSampler)
+from .transforms import (bbox2distance, bbox2result, bbox2roi,
+ bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
+ bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
+ distance2bbox, roi2bbox)
+
+__all__ = [
+ 'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
+ 'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+ 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+ 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
+ 'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
+ 'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
+ 'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
+ 'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
+ 'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh',
+ 'RegionAssigner'
+]
diff --git a/mmdet/core/bbox/assigners/__init__.py b/mmdet/core/bbox/assigners/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..95e34a848652f2ab3ca6d3489aa2934d24817888
--- /dev/null
+++ b/mmdet/core/bbox/assigners/__init__.py
@@ -0,0 +1,16 @@
+from .approx_max_iou_assigner import ApproxMaxIoUAssigner
+from .assign_result import AssignResult
+from .atss_assigner import ATSSAssigner
+from .base_assigner import BaseAssigner
+from .center_region_assigner import CenterRegionAssigner
+from .grid_assigner import GridAssigner
+from .hungarian_assigner import HungarianAssigner
+from .max_iou_assigner import MaxIoUAssigner
+from .point_assigner import PointAssigner
+from .region_assigner import RegionAssigner
+
+__all__ = [
+ 'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
+ 'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
+ 'HungarianAssigner', 'RegionAssigner'
+]
diff --git a/mmdet/core/bbox/assigners/approx_max_iou_assigner.py b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d07656d173744426795c81c14c6bcdb4e63a406
--- /dev/null
+++ b/mmdet/core/bbox/assigners/approx_max_iou_assigner.py
@@ -0,0 +1,145 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .max_iou_assigner import MaxIoUAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class ApproxMaxIoUAssigner(MaxIoUAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned with an integer indicating the ground truth
+ index. (semi-positive index: gt label (0-based), -1: background)
+
+ - -1: negative sample, no assigned gt
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+ `gt_bboxes_ignore` is specified). Negative values mean not
+ ignoring any bboxes.
+ ignore_wrt_candidates (bool): Whether to compute the iof between
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
+ match_low_quality (bool): Whether to allow low quality matches. This is
+ usually allowed for RPN and single stage detectors, but not allowed
+ in the second stage.
+ gpu_assign_thr (int): The upper bound of the number of GT for GPU
+ assign. When the number of gt is above this threshold, will assign
+ on CPU device. Negative values mean not assign on CPU.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ match_low_quality=True,
+ gpu_assign_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.match_low_quality = match_low_quality
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self,
+ approxs,
+ squares,
+ approxs_per_octave,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None):
+ """Assign gt to approxs.
+
+ This method assigns a gt bbox to each group of approxs (bboxes);
+ each group of approxs is represented by a base approx (bbox) and
+ will be assigned with -1, or a semi-positive number.
+ background_label (-1) means negative sample,
+ semi-positive number is the index (0-based) of assigned gt.
+ The assignment is done in following steps, the order matters.
+
+ 1. assign every bbox to background_label (-1)
+ 2. use the max IoU of each group of approxs to assign
+ 3. assign proposals whose iou with all gts < neg_iou_thr to background
+ 4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+ assign it to that gt
+ 5. for each gt bbox, assign its nearest proposals (may be more than
+ one) to itself
+
+ Args:
+ approxs (Tensor): Bounding boxes to be assigned,
+ shape(approxs_per_octave*n, 4).
+ squares (Tensor): Base Bounding boxes to be assigned,
+ shape(n, 4).
+ approxs_per_octave (int): number of approxs per octave
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ num_squares = squares.size(0)
+ num_gts = gt_bboxes.size(0)
+
+ if num_squares == 0 or num_gts == 0:
+ # No predictions and/or truth, return empty assignment
+ overlaps = approxs.new(num_gts, num_squares)
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+ return assign_result
+
+ # re-organize anchors by approxs_per_octave x num_squares
+ approxs = torch.transpose(
+ approxs.view(num_squares, approxs_per_octave, 4), 0,
+ 1).contiguous().view(-1, 4)
+ assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+ num_gts > self.gpu_assign_thr) else False
+ # compute overlap and assign gt on CPU when number of GT is large
+ if assign_on_cpu:
+ device = approxs.device
+ approxs = approxs.cpu()
+ gt_bboxes = gt_bboxes.cpu()
+ if gt_bboxes_ignore is not None:
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+ if gt_labels is not None:
+ gt_labels = gt_labels.cpu()
+ all_overlaps = self.iou_calculator(approxs, gt_bboxes)
+
+ overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
+ num_gts).max(dim=0)
+ overlaps = torch.transpose(overlaps, 0, 1)
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
+ if self.ignore_wrt_candidates:
+ ignore_overlaps = self.iou_calculator(
+ squares, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ else:
+ ignore_overlaps = self.iou_calculator(
+ gt_bboxes_ignore, squares, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+ if assign_on_cpu:
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+ if assign_result.labels is not None:
+ assign_result.labels = assign_result.labels.to(device)
+ return assign_result
diff --git a/mmdet/core/bbox/assigners/assign_result.py b/mmdet/core/bbox/assigners/assign_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..4639fbdba0a5b92778e1ab87d61182e54bfb9b6f
--- /dev/null
+++ b/mmdet/core/bbox/assigners/assign_result.py
@@ -0,0 +1,204 @@
+import torch
+
+from mmdet.utils import util_mixins
+
+
+class AssignResult(util_mixins.NiceRepr):
+ """Stores assignments between predicted and truth boxes.
+
+ Attributes:
+ num_gts (int): the number of truth boxes considered when computing this
+ assignment
+
+ gt_inds (LongTensor): for each predicted box indicates the 1-based
+ index of the assigned truth box. 0 means unassigned and -1 means
+ ignore.
+
+ max_overlaps (FloatTensor): the iou between the predicted box and its
+ assigned truth box.
+
+ labels (None | LongTensor): If specified, for each predicted box
+ indicates the category label of the assigned truth box.
+
+ Example:
+ >>> # An assign result between 4 predicted boxes and 9 true boxes
+ >>> # where only two boxes were assigned.
+ >>> num_gts = 9
+ >>> max_overlaps = torch.FloatTensor([0, .5, .9, 0])
+ >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
+ >>> labels = torch.LongTensor([0, 3, 4, 0])
+ >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
+
+ >>> # Force addition of gt labels (when adding gt as proposals)
+ >>> new_labels = torch.LongTensor([3, 4, 5])
+ >>> self.add_gt_(new_labels)
+ >>> print(str(self)) # xdoctest: +IGNORE_WANT
+
+ """
+
+ def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
+ self.num_gts = num_gts
+ self.gt_inds = gt_inds
+ self.max_overlaps = max_overlaps
+ self.labels = labels
+ # Interface for possible user-defined properties
+ self._extra_properties = {}
+
+ @property
+ def num_preds(self):
+ """int: the number of predictions in this assignment"""
+ return len(self.gt_inds)
+
+ def set_extra_property(self, key, value):
+ """Set user-defined new property."""
+ assert key not in self.info
+ self._extra_properties[key] = value
+
+ def get_extra_property(self, key):
+ """Get user-defined property."""
+ return self._extra_properties.get(key, None)
+
+ @property
+ def info(self):
+ """dict: a dictionary of info about the object"""
+ basic_info = {
+ 'num_gts': self.num_gts,
+ 'num_preds': self.num_preds,
+ 'gt_inds': self.gt_inds,
+ 'max_overlaps': self.max_overlaps,
+ 'labels': self.labels,
+ }
+ basic_info.update(self._extra_properties)
+ return basic_info
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this assign result"""
+ parts = []
+ parts.append(f'num_gts={self.num_gts!r}')
+ if self.gt_inds is None:
+ parts.append(f'gt_inds={self.gt_inds!r}')
+ else:
+ parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
+ if self.max_overlaps is None:
+ parts.append(f'max_overlaps={self.max_overlaps!r}')
+ else:
+ parts.append('max_overlaps.shape='
+ f'{tuple(self.max_overlaps.shape)!r}')
+ if self.labels is None:
+ parts.append(f'labels={self.labels!r}')
+ else:
+ parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
+ return ', '.join(parts)
+
+ @classmethod
+ def random(cls, **kwargs):
+ """Create random AssignResult for tests or debugging.
+
+ Args:
+ num_preds: number of predicted boxes
+ num_gts: number of true boxes
+ p_ignore (float): probability of a predicted box assigned to an
+ ignored truth
+ p_assigned (float): probability of a predicted box not being
+ assigned
+ p_use_label (float | bool): with labels or not
+ rng (None | int | numpy.random.RandomState): seed or state
+
+ Returns:
+ :obj:`AssignResult`: Randomly generated assign results.
+
+ Example:
+ >>> from mmdet.core.bbox.assigners.assign_result import * # NOQA
+ >>> self = AssignResult.random()
+ >>> print(self.info)
+ """
+ from mmdet.core.bbox import demodata
+ rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+ num_gts = kwargs.get('num_gts', None)
+ num_preds = kwargs.get('num_preds', None)
+ p_ignore = kwargs.get('p_ignore', 0.3)
+ p_assigned = kwargs.get('p_assigned', 0.7)
+ p_use_label = kwargs.get('p_use_label', 0.5)
+ num_classes = kwargs.get('num_classes', 3)
+
+ if num_gts is None:
+ num_gts = rng.randint(0, 8)
+ if num_preds is None:
+ num_preds = rng.randint(0, 16)
+
+ if num_gts == 0:
+ max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+ if p_use_label is True or p_use_label < rng.rand():
+ labels = torch.zeros(num_preds, dtype=torch.int64)
+ else:
+ labels = None
+ else:
+ import numpy as np
+ # Create an overlap for each predicted box
+ max_overlaps = torch.from_numpy(rng.rand(num_preds))
+
+ # Construct gt_inds for each predicted box
+ is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
+ # maximum number of assignments constraints
+ n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
+
+ assigned_idxs = np.where(is_assigned)[0]
+ rng.shuffle(assigned_idxs)
+ assigned_idxs = assigned_idxs[0:n_assigned]
+ assigned_idxs.sort()
+
+ is_assigned[:] = 0
+ is_assigned[assigned_idxs] = True
+
+ is_ignore = torch.from_numpy(
+ rng.rand(num_preds) < p_ignore) & is_assigned
+
+ gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+
+ true_idxs = np.arange(num_gts)
+ rng.shuffle(true_idxs)
+ true_idxs = torch.from_numpy(true_idxs)
+ gt_inds[is_assigned] = true_idxs[:n_assigned]
+
+ gt_inds = torch.from_numpy(
+ rng.randint(1, num_gts + 1, size=num_preds))
+ gt_inds[is_ignore] = -1
+ gt_inds[~is_assigned] = 0
+ max_overlaps[~is_assigned] = 0
+
+ if p_use_label is True or p_use_label < rng.rand():
+ if num_classes == 0:
+ labels = torch.zeros(num_preds, dtype=torch.int64)
+ else:
+ labels = torch.from_numpy(
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ rng.randint(0, num_classes, size=num_preds))
+ labels[~is_assigned] = 0
+ else:
+ labels = None
+
+ self = cls(num_gts, gt_inds, max_overlaps, labels)
+ return self
+
+ def add_gt_(self, gt_labels):
+ """Add ground truth as assigned results.
+
+ Args:
+ gt_labels (torch.Tensor): Labels of gt boxes
+ """
+ self_inds = torch.arange(
+ 1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
+ self.gt_inds = torch.cat([self_inds, self.gt_inds])
+
+ self.max_overlaps = torch.cat(
+ [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+
+ if self.labels is not None:
+ self.labels = torch.cat([gt_labels, self.labels])
diff --git a/mmdet/core/bbox/assigners/atss_assigner.py b/mmdet/core/bbox/assigners/atss_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4fe9d0e3c8704bd780d493eff20a5505dbe9580
--- /dev/null
+++ b/mmdet/core/bbox/assigners/atss_assigner.py
@@ -0,0 +1,178 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class ATSSAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned with `0` or a positive integer
+ indicating the ground truth index.
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ topk (int): number of bboxes selected in each level
+ """
+
+ def __init__(self,
+ topk,
+ iou_calculator=dict(type='BboxOverlaps2D'),
+ ignore_iof_thr=-1):
+ self.topk = topk
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+ self.ignore_iof_thr = ignore_iof_thr
+
+ # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+
+ def assign(self,
+ bboxes,
+ num_level_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ gt_labels=None):
+ """Assign gt to bboxes.
+
+ The assignment is done in following steps
+
+ 1. compute iou between all bbox (bbox of all pyramid levels) and gt
+ 2. compute center distance between all bbox and gt
+ 3. on each pyramid level, for each gt, select k bboxes whose centers
+ are closest to the gt center, so we select k*l bboxes in total as
+ candidates for each gt
+ 4. get the corresponding iou for these candidates, and compute the
+ mean and std, then set mean + std as the iou threshold
+ 5. select these candidates whose iou are greater than or equal to
+ the threshold as positive
+ 6. limit the positive sample's center in gt
+
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ num_level_bboxes (List): num of bboxes in each level
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ INF = 100000000
+ bboxes = bboxes[:, :4]
+ num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+ # compute iou between all bbox and gt
+ overlaps = self.iou_calculator(bboxes, gt_bboxes)
+
+ # assign 0 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ 0,
+ dtype=torch.long)
+
+ if num_gt == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gt == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+ # compute center distance between all bbox and gt
+ gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+ gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+ gt_points = torch.stack((gt_cx, gt_cy), dim=1)
+
+ bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+ bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+ bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
+
+ distances = (bboxes_points[:, None, :] -
+ gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+ ignore_overlaps = self.iou_calculator(
+ bboxes, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
+ distances[ignore_idxs, :] = INF
+ assigned_gt_inds[ignore_idxs] = -1
+
+ # Selecting candidates based on the center distance
+ candidate_idxs = []
+ start_idx = 0
+ for level, bboxes_per_level in enumerate(num_level_bboxes):
+ # on each pyramid level, for each gt,
+ # select k bbox whose center are closest to the gt center
+ end_idx = start_idx + bboxes_per_level
+ distances_per_level = distances[start_idx:end_idx, :]
+ selectable_k = min(self.topk, bboxes_per_level)
+ _, topk_idxs_per_level = distances_per_level.topk(
+ selectable_k, dim=0, largest=False)
+ candidate_idxs.append(topk_idxs_per_level + start_idx)
+ start_idx = end_idx
+ candidate_idxs = torch.cat(candidate_idxs, dim=0)
+
+ # get the corresponding iou for these candidates, and compute the
+ # mean and std; set mean + std as the iou threshold
+ candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
+ overlaps_mean_per_gt = candidate_overlaps.mean(0)
+ overlaps_std_per_gt = candidate_overlaps.std(0)
+ overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+ is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+ # limit the positive sample's center in gt
+ for gt_idx in range(num_gt):
+ candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+ ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
+ num_gt, num_bboxes).contiguous().view(-1)
+ ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
+ num_gt, num_bboxes).contiguous().view(-1)
+ candidate_idxs = candidate_idxs.view(-1)
+
+ # calculate the left, top, right, bottom distance between positive
+ # bbox center and gt side
+ l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
+ t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
+ r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
+ b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
+ is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+ is_pos = is_pos & is_in_gts
+
+ # if an anchor box is assigned to multiple gts,
+ # the one with the highest IoU will be selected.
+ overlaps_inf = torch.full_like(overlaps,
+ -INF).t().contiguous().view(-1)
+ index = candidate_idxs.view(-1)[is_pos.view(-1)]
+ overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+ overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+ max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
+ assigned_gt_inds[
+ max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+ return AssignResult(
+ num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
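+
+ # Worked example of the adaptive threshold (illustrative IoUs): if a
+ # gt's candidates, pooled over all levels, have IoUs with mean 0.30
+ # and std 0.12, its threshold is 0.42; only candidates with
+ # IoU >= 0.42 whose centers also lie inside that gt become positives.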
diff --git a/mmdet/core/bbox/assigners/base_assigner.py b/mmdet/core/bbox/assigners/base_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ff0160dbb4bfbf53cb40d1d5cb29bcc3d197a59
--- /dev/null
+++ b/mmdet/core/bbox/assigners/base_assigner.py
@@ -0,0 +1,9 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BaseAssigner(metaclass=ABCMeta):
+ """Base assigner that assigns boxes to ground truth boxes."""
+
+ @abstractmethod
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign boxes to either a ground truth boxes or a negative boxes."""
diff --git a/mmdet/core/bbox/assigners/center_region_assigner.py b/mmdet/core/bbox/assigners/center_region_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..488e3b615318787751cab3211e38dd9471c666be
--- /dev/null
+++ b/mmdet/core/bbox/assigners/center_region_assigner.py
@@ -0,0 +1,335 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+def scale_boxes(bboxes, scale):
+ """Expand an array of boxes by a given scale.
+
+ Args:
+ bboxes (Tensor): Shape (m, 4)
+ scale (float): The scale factor of bboxes
+
+ Returns:
+ (Tensor): Shape (m, 4). Scaled bboxes
+ """
+ assert bboxes.size(1) == 4
+ w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
+ h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
+ x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
+ y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
+
+ w_half *= scale
+ h_half *= scale
+
+ boxes_scaled = torch.zeros_like(bboxes)
+ boxes_scaled[:, 0] = x_c - w_half
+ boxes_scaled[:, 2] = x_c + w_half
+ boxes_scaled[:, 1] = y_c - h_half
+ boxes_scaled[:, 3] = y_c + h_half
+ return boxes_scaled
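+
+# Worked example (illustrative box): scaling bbox (0, 0, 10, 10) by 0.2
+# keeps the center (5, 5) and shrinks each half-extent from 5 to 1, giving
+# the core region (4, 4, 6, 6).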
+
+
+def is_located_in(points, bboxes):
+ """Are points located in bboxes.
+
+ Args:
+ points (Tensor): Points, shape: (m, 2).
+ bboxes (Tensor): Bounding boxes, shape: (n, 4).
+
+ Returns:
+ Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
+ """
+ assert points.size(1) == 2
+ assert bboxes.size(1) == 4
+ return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
+ (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
+ (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
+ (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
+
+
+def bboxes_area(bboxes):
+ """Compute the area of an array of bboxes.
+
+ Args:
+ bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)
+
+ Returns:
+ Tensor: Area of the bboxes. Shape: (m, )
+ """
+ assert bboxes.size(1) == 4
+ w = (bboxes[:, 2] - bboxes[:, 0])
+ h = (bboxes[:, 3] - bboxes[:, 1])
+ areas = w * h
+ return areas
+
+
+@BBOX_ASSIGNERS.register_module()
+class CenterRegionAssigner(BaseAssigner):
+ """Assign pixels at the center region of a bbox as positive.
+
+ Each proposal will be assigned with `-1`, `0`, or a positive integer
+ indicating the ground truth index.
+ - -1: negative samples
+ - semi-positive numbers: positive sample, index (0-based) of assigned gt
+
+ Args:
+ pos_scale (float): Threshold within which pixels are
+ labelled as positive.
+ neg_scale (float): Threshold above which pixels are
+ labelled as negative.
+ min_pos_iof (float): Minimum iof of a pixel with a gt to be
+ labelled as positive. Default: 1e-2
+ ignore_gt_scale (float): Threshold within which the pixels
+ are ignored when the gt is labelled as shadowed. Default: 0.5
+ foreground_dominate (bool): If True, the bbox will be assigned as
+ positive when a gt's kernel region overlaps with another's shadowed
+ (ignored) region, otherwise it is set as ignored. Default to False.
+ """
+
+ def __init__(self,
+ pos_scale,
+ neg_scale,
+ min_pos_iof=1e-2,
+ ignore_gt_scale=0.5,
+ foreground_dominate=False,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_scale = pos_scale
+ self.neg_scale = neg_scale
+ self.min_pos_iof = min_pos_iof
+ self.ignore_gt_scale = ignore_gt_scale
+ self.foreground_dominate = foreground_dominate
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def get_gt_priorities(self, gt_bboxes):
+ """Get gt priorities according to their areas.
+
+ Smaller gt has higher priority.
+
+ Args:
+ gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
+
+ Returns:
+ Tensor: The priority of gts so that gts with larger priority are \
+ more likely to be assigned. Shape (k, )
+ """
+ gt_areas = bboxes_area(gt_bboxes)
+ # Rank all gt bbox areas. Smaller objects have higher priority
+ _, sort_idx = gt_areas.sort(descending=True)
+ sort_idx = sort_idx.argsort()
+ return sort_idx
+
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to bboxes.
+
+ This method assigns gts to every bbox (proposal/anchor), each bbox \
+ will be assigned with -1, or a semi-positive number. -1 means \
+ negative sample, semi-positive number is the index (0-based) of \
+ assigned gt.
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).
+
+ Returns:
+ :obj:`AssignResult`: The assigned result. Note that \
+ shadowed_labels of shape (N, 2) is also added as an \
+ `assign_result` attribute. `shadowed_labels` is a tensor \
+ composed of N pairs of [anchor_ind, class_label], where N \
+ is the number of anchors that lie in the outer region of a \
+ gt, anchor_ind is the shadowed anchor index and class_label \
+ is the shadowed class label.
+
+ Example:
+ >>> self = CenterRegionAssigner(0.2, 0.2)
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+ """
+ # There are in total 5 steps in the pixel assignment
+ # 1. Find core (the center region, say inner 0.2)
+ # and shadow (the relatively outer part, say inner 0.2-0.5)
+ # regions of every gt.
+ # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
+ # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
+ # the image.
+ # 3.1. For overlapping objects, the prior bboxes in gt_core are
+ # assigned to the object with the smallest area
+ # 4. Assign prior bboxes with class label according to its gt id.
+ # 4.1. Assign -1 to prior bboxes lying in shadowed gts
+ # 4.2. Assign positive prior boxes with the corresponding label
+ # 5. Find pixels lying in the shadow of an object and assign them with
+ # background label, but set the loss weight of its corresponding
+ # gt to zero.
+ assert bboxes.size(1) == 4, 'bboxes must have size of 4'
+ # 1. Find core positive and shadow region of every gt
+ gt_core = scale_boxes(gt_bboxes, self.pos_scale)
+ gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
+
+ # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
+ bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
+ # The center points lie within the gt boxes
+ is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
+ # Only calculate bbox and gt_core IoF. This enables small prior bboxes
+ # to match large gts
+ bbox_and_gt_core_overlaps = self.iou_calculator(
+ bboxes, gt_core, mode='iof')
+ # The center point of effective priors should be within the gt box
+ is_bbox_in_gt_core = is_bbox_in_gt & (
+ bbox_and_gt_core_overlaps > self.min_pos_iof) # shape (n, k)
+
+ is_bbox_in_gt_shadow = (
+ self.iou_calculator(bboxes, gt_shadow, mode='iof') >
+ self.min_pos_iof)
+ # Rule out center effective positive pixels
+ is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)
+
+ num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+ if num_gts == 0 or num_bboxes == 0:
+ # If no gts exist, assign all pixels to negative
+ assigned_gt_ids = \
+ is_bbox_in_gt_core.new_zeros((num_bboxes,),
+ dtype=torch.long)
+ pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
+ else:
+ # Step 3: assign a one-hot gt id to each pixel, and smaller objects
+ # have high priority to assign the pixel.
+ sort_idx = self.get_gt_priorities(gt_bboxes)
+ assigned_gt_ids, pixels_in_gt_shadow = \
+ self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
+ is_bbox_in_gt_shadow,
+ gt_priority=sort_idx)
+
+ if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
+ # Ignore priors whose centers fall inside the scaled ignored gts
+ gt_bboxes_ignore = scale_boxes(
+ gt_bboxes_ignore, scale=self.ignore_gt_scale)
+ is_bbox_in_ignored_gts = is_located_in(bbox_centers,
+ gt_bboxes_ignore)
+ is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
+ assigned_gt_ids[is_bbox_in_ignored_gts] = -1
+
+ # 4. Assign prior bboxes with class label according to its gt id.
+ assigned_labels = None
+ shadowed_pixel_labels = None
+ if gt_labels is not None:
+ # Default assigned label is the background (-1)
+ assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_ids > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
+ - 1]
+ # 5. Find pixels lying in the shadow of an object
+ shadowed_pixel_labels = pixels_in_gt_shadow.clone()
+ if pixels_in_gt_shadow.numel() > 0:
+ pixel_idx, gt_idx =\
+ pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
+ assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
+ 'Some pixels are dually assigned to ignore and gt!'
+ shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
+ override = (
+ assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
+ if self.foreground_dominate:
+ # When a pixel is both positive and shadowed, set it as pos
+ shadowed_pixel_labels = shadowed_pixel_labels[~override]
+ else:
+ # When a pixel is both pos and shadowed, set it as shadowed
+ assigned_labels[pixel_idx[override]] = -1
+ assigned_gt_ids[pixel_idx[override]] = 0
+
+ assign_result = AssignResult(
+ num_gts, assigned_gt_ids, None, labels=assigned_labels)
+ # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
+ assign_result.set_extra_property('shadowed_labels',
+ shadowed_pixel_labels)
+ return assign_result
+
+ def assign_one_hot_gt_indices(self,
+ is_bbox_in_gt_core,
+ is_bbox_in_gt_shadow,
+ gt_priority=None):
+ """Assign only one gt index to each prior box.
+
+ Gts with large gt_priority are more likely to be assigned.
+
+ Args:
+ is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
+ is in the core area of a gt (e.g. 0-0.2).
+ Shape: (num_prior, num_gt).
+ is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
+ center is in the shadowed area of a gt (e.g. 0.2-0.5).
+ Shape: (num_prior, num_gt).
+ gt_priority (Tensor): Priorities of gts. The gt with a higher
+ priority is more likely to be assigned to the bbox when the bbox
+ match with multiple gts. Shape: (num_gt, ).
+
+ Returns:
+ tuple: Returns (assigned_gt_inds, shadowed_gt_inds).
+
+ - assigned_gt_inds: The assigned gt index of each prior bbox \
+ (i.e. index from 1 to num_gts). Shape: (num_prior, ).
+ - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
+ shape (num_ignore, 2) with first column being the \
+ shadowed prior bbox indices and the second column the \
+ shadowed gt indices (1-based).
+ """
+ num_bboxes, num_gts = is_bbox_in_gt_core.shape
+
+ if gt_priority is None:
+ gt_priority = torch.arange(
+ num_gts, device=is_bbox_in_gt_core.device)
+ assert gt_priority.size(0) == num_gts
+ # The larger the gt_priority, the more preferentially the gt is assigned
+ # The assigned inds are by default 0 (background)
+ assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
+ dtype=torch.long)
+ # Shadowed bboxes are assigned to be background. But the corresponding
+ # label is ignored during loss calculation, which is done through
+ # shadowed_gt_inds
+ shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
+ if is_bbox_in_gt_core.sum() == 0: # No gt match
+ shadowed_gt_inds[:, 1] += 1 # 1-based, consistent with assigned_gt_inds
+ return assigned_gt_inds, shadowed_gt_inds
+
+ # The priority of each prior box and gt pair. If one prior box is
+ # matched to multiple gts, only the pair with the highest priority
+ # is saved
+ pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
+ -1,
+ dtype=torch.long)
+
+ # Each bbox could match with multiple gts.
+ # The following codes deal with this situation
+ # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
+ inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
+ # The matched gt index of each positive bbox. Length >= num_pos_anchor
+ # , since one bbox could match multiple gts
+ matched_bbox_gt_inds = torch.nonzero(
+ is_bbox_in_gt_core, as_tuple=False)[:, 1]
+ # Assign priority to each bbox-gt pair.
+ pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
+ _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
+ assigned_gt_inds[inds_of_match] = argmax_priority + 1 # 1-based
+ # Zero-out the assigned anchor box to filter the shadowed gt indices
+ is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
+ # Concatenate the pairs outside of the effective scale (shadow region)
+ # with the core pairs dropped due to overlapping gts.
+ # Shape: (total_num_ignore, 2)
+ shadowed_gt_inds = torch.cat(
+ (shadowed_gt_inds, torch.nonzero(
+ is_bbox_in_gt_core, as_tuple=False)),
+ dim=0)
+ # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
+ is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
+ # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
+ if shadowed_gt_inds.numel() > 0:
+ shadowed_gt_inds[:, 1] += 1
+ return assigned_gt_inds, shadowed_gt_inds
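+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): resolves
+     # multi-gt matches by priority, mirroring the logic above. The toy
+     # tensors below are made up; torch is already imported in this module.
+     is_in_core = torch.tensor([[True, True],    # prior 0 hits gt 0 and gt 1
+                                [False, True],   # prior 1 hits gt 1 only
+                                [False, False]])  # prior 2 hits nothing
+     gt_priority = torch.tensor([0, 1])  # gt 1 is preferred over gt 0
+     pair_priority = is_in_core.new_full(is_in_core.shape, -1, dtype=torch.long)
+     matched_gt = torch.nonzero(is_in_core, as_tuple=False)[:, 1]
+     pair_priority[is_in_core] = gt_priority[matched_gt]
+     matched = torch.any(is_in_core, dim=1)
+     _, best = pair_priority[matched].max(dim=1)
+     assigned = torch.zeros(is_in_core.size(0), dtype=torch.long)
+     assigned[matched] = best + 1  # 1-based gt indices, 0 is background
+     print(assigned)  # tensor([2, 2, 0]): prior 0 takes the higher-priority gt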
diff --git a/mmdet/core/bbox/assigners/grid_assigner.py b/mmdet/core/bbox/assigners/grid_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..7390ea6370639c939d578c6ebf0f9268499161bc
--- /dev/null
+++ b/mmdet/core/bbox/assigners/grid_assigner.py
@@ -0,0 +1,155 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class GridAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned `-1`, `0`, or a positive integer
+ indicating the ground truth index.
+
+ - -1: don't care
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
+ """Assign gt to bboxes. The process is very much like the max iou
+ assigner, except that positive samples are constrained within the cell
+ that the gt boxes fell in.
+
+ This method assign a gt bbox to every bbox (proposal/anchor), each bbox
+ will be assigned with -1, 0, or a positive number. -1 means don't care,
+ 0 means negative sample, positive number is the index (1-based) of
+ assigned gt.
+ The assignment is done in following steps, the order matters.
+
+ 1. assign every bbox to -1
+ 2. assign proposals whose iou with all gts <= neg_iou_thr to 0
+ 3. for each bbox within a cell, if the iou with its nearest gt >
+ pos_iou_thr and the center of that gt falls inside the cell,
+ assign it to that bbox
+ 4. for each gt bbox, assign its nearest proposals within the cell the
+ gt bbox falls in to itself.
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ box_responsible_flags (Tensor): flag to indicate whether box is
+ responsible for prediction, shape(n, )
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+ # compute iou between all gt and bboxes
+ overlaps = self.iou_calculator(gt_bboxes, bboxes)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # 2. assign negative: below
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ # shape of max_overlaps == argmax_overlaps == num_bboxes
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+ if isinstance(self.neg_iou_thr, float):
+ assigned_gt_inds[(max_overlaps >= 0)
+ & (max_overlaps <= self.neg_iou_thr)] = 0
+ elif isinstance(self.neg_iou_thr, (tuple, list)):
+ assert len(self.neg_iou_thr) == 2
+ assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
+ & (max_overlaps <= self.neg_iou_thr[1])] = 0
+
+ # 3. assign positive: falls into responsible cell and above
+ # positive IOU threshold, the order matters.
+ # the prior condition of the comparison is to filter out all
+ # unrelated anchors, i.e. not box_responsible_flags
+ overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.
+
+ # calculate max_overlaps again, but this time we only consider IOUs
+ # for anchors responsible for prediction
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+ # for each gt, which anchor best overlaps with it
+ # for each gt, the max iou of all proposals
+ # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
+ gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
+
+ pos_inds = (max_overlaps >
+ self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
+ assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
+
+ # 4. assign positive to max overlapped anchors within responsible cell
+ for i in range(num_gts):
+ if gt_max_overlaps[i] > self.min_pos_iou:
+ if self.gt_max_assign_all:
+ max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
+ box_responsible_flags.type(torch.bool)
+ assigned_gt_inds[max_iou_inds] = i + 1
+ elif box_responsible_flags[gt_argmax_overlaps[i]]:
+ assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
+
+ # assign labels of positive anchors
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
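+
+
+ if __name__ == '__main__':
+     # Minimal usage sketch (added for illustration, not part of mmdet).
+     # The thresholds and toy boxes are made up; only the first anchor's
+     # cell is marked responsible, so only it can become positive.
+     assigner = GridAssigner(pos_iou_thr=0.5, neg_iou_thr=0.3)
+     anchors = torch.Tensor([[0, 0, 10, 10], [20, 20, 30, 30]])
+     responsible = torch.tensor([True, False])
+     gts = torch.Tensor([[0, 0, 10, 9]])
+     gt_labels = torch.LongTensor([1])
+     result = assigner.assign(anchors, responsible, gts, gt_labels)
+     print(result.gt_inds)  # tensor([1, 0]): anchor 0 -> gt 1, anchor 1 -> bg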
diff --git a/mmdet/core/bbox/assigners/hungarian_assigner.py b/mmdet/core/bbox/assigners/hungarian_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..e10cc14afac4ddfcb9395c1a250ece1fbfe3263c
--- /dev/null
+++ b/mmdet/core/bbox/assigners/hungarian_assigner.py
@@ -0,0 +1,145 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..match_costs import build_match_cost
+from ..transforms import bbox_cxcywh_to_xyxy
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+try:
+ from scipy.optimize import linear_sum_assignment
+except ImportError:
+ linear_sum_assignment = None
+
+
+@BBOX_ASSIGNERS.register_module()
+class HungarianAssigner(BaseAssigner):
+ """Computes one-to-one matching between predictions and ground truth.
+
+ This class computes an assignment between the targets and the predictions
+ based on the costs. The costs are weighted sum of three components:
+ classification cost, regression L1 cost and regression iou cost. The
+ targets don't include the no_object, so generally there are more
+ predictions than targets. After the one-to-one matching, the unmatched
+ ones are treated as background. Thus each query prediction will be assigned
+ with `0` or a positive integer indicating the ground truth index:
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ cls_cost (dict, optional): Config of the classification cost.
+ Defaults to dict(type='ClassificationCost', weight=1.).
+ reg_cost (dict, optional): Config of the regression L1 cost.
+ Defaults to dict(type='BBoxL1Cost', weight=1.0).
+ iou_cost (dict, optional): Config of the regression iou cost.
+ Defaults to dict(type='IoUCost', iou_mode='giou', weight=1.0).
+
+ def __init__(self,
+ cls_cost=dict(type='ClassificationCost', weight=1.),
+ reg_cost=dict(type='BBoxL1Cost', weight=1.0),
+ iou_cost=dict(type='IoUCost', iou_mode='giou', weight=1.0)):
+ self.cls_cost = build_match_cost(cls_cost)
+ self.reg_cost = build_match_cost(reg_cost)
+ self.iou_cost = build_match_cost(iou_cost)
+
+ def assign(self,
+ bbox_pred,
+ cls_pred,
+ gt_bboxes,
+ gt_labels,
+ img_meta,
+ gt_bboxes_ignore=None,
+ eps=1e-7):
+ """Computes one-to-one matching based on the weighted costs.
+
+ This method assigns each query prediction to a ground truth or
+ background. The `assigned_gt_inds` with -1 means don't care,
+ 0 means negative sample, and a positive number is the index (1-based)
+ of the assigned gt.
+ The assignment is done in the following steps, and the order matters.
+
+ 1. assign every prediction to -1
+ 2. compute the weighted costs
+ 3. do Hungarian matching on CPU based on the costs
+ 4. assign all to 0 (background) first, then for each matched pair
+ between predictions and gts, treat this prediction as foreground
+ and assign the corresponding gt index (plus 1) to it.
+
+ Args:
+ bbox_pred (Tensor): Predicted boxes with normalized coordinates
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
+ [num_query, 4].
+ cls_pred (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized
+ coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+ img_meta (dict): Meta information for current image.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`. Default None.
+ eps (int | float, optional): A value added to the denominator for
+ numerical stability. Default 1e-7.
+
+ Returns:
+ :obj:`AssignResult`: The assigned result.
+ """
+ assert gt_bboxes_ignore is None, \
+ 'Only case when gt_bboxes_ignore is None is supported.'
+ num_gts, num_bboxes = gt_bboxes.size(0), bbox_pred.size(0)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ assigned_labels = bbox_pred.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ if num_gts == 0:
+ # No ground truth, assign all to background
+ assigned_gt_inds[:] = 0
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
+ img_h, img_w, _ = img_meta['img_shape']
+ factor = gt_bboxes.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0)
+
+ # 2. compute the weighted costs
+ # classification and bboxcost.
+ cls_cost = self.cls_cost(cls_pred, gt_labels)
+ # regression L1 cost
+ normalize_gt_bboxes = gt_bboxes / factor
+ reg_cost = self.reg_cost(bbox_pred, normalize_gt_bboxes)
+ # regression iou cost; giou is used by default, as in official DETR.
+ bboxes = bbox_cxcywh_to_xyxy(bbox_pred) * factor
+ iou_cost = self.iou_cost(bboxes, gt_bboxes)
+ # weighted sum of above three costs
+ cost = cls_cost + reg_cost + iou_cost
+
+ # 3. do Hungarian matching on CPU using linear_sum_assignment
+ cost = cost.detach().cpu()
+ if linear_sum_assignment is None:
+ raise ImportError('Please run "pip install scipy" '
+ 'to install scipy first.')
+ matched_row_inds, matched_col_inds = linear_sum_assignment(cost)
+ matched_row_inds = torch.from_numpy(matched_row_inds).to(
+ bbox_pred.device)
+ matched_col_inds = torch.from_numpy(matched_col_inds).to(
+ bbox_pred.device)
+
+ # 4. assign backgrounds and foregrounds
+ # assign all indices to backgrounds first
+ assigned_gt_inds[:] = 0
+ # assign foregrounds based on matching results
+ assigned_gt_inds[matched_row_inds] = matched_col_inds + 1
+ assigned_labels[matched_row_inds] = gt_labels[matched_col_inds]
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
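+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): step 3 in
+     # isolation. A toy 3x2 cost matrix (3 queries, 2 gts) is matched
+     # one-to-one with scipy; the unmatched query stays background (0) and
+     # matched queries get 1-based gt indices, as in step 4 above.
+     cost = torch.tensor([[0.1, 0.9],
+                          [0.8, 0.2],
+                          [0.7, 0.6]])
+     rows, cols = linear_sum_assignment(cost.numpy())
+     assigned = torch.zeros(cost.size(0), dtype=torch.long)
+     assigned[torch.from_numpy(rows)] = torch.from_numpy(cols) + 1
+     print(assigned)  # tensor([1, 2, 0])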
diff --git a/mmdet/core/bbox/assigners/max_iou_assigner.py b/mmdet/core/bbox/assigners/max_iou_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cf4c4b4b450f87dfb99c3d33d8ed83d3e5cfcb3
--- /dev/null
+++ b/mmdet/core/bbox/assigners/max_iou_assigner.py
@@ -0,0 +1,212 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from ..iou_calculators import build_iou_calculator
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class MaxIoUAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned `-1`, or a semi-positive integer
+ indicating the ground truth index.
+
+ - -1: negative sample, no assigned gt
+ - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+ Args:
+ pos_iou_thr (float): IoU threshold for positive bboxes.
+ neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+ min_pos_iou (float): Minimum iou for a bbox to be considered as a
+ positive bbox. Positive samples can have smaller IoU than
+ pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+ gt_max_assign_all (bool): Whether to assign all bboxes with the same
+ highest overlap with some gt to that gt.
+ ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+ `gt_bboxes_ignore` is specified). Negative values mean not
+ ignoring any bboxes.
+ ignore_wrt_candidates (bool): Whether to compute the iof between
+ `bboxes` and `gt_bboxes_ignore`, or the contrary.
+ match_low_quality (bool): Whether to allow low quality matches. This is
+ usually allowed for RPN and single stage detectors, but not allowed
+ in the second stage. Details are demonstrated in Step 4.
+ gpu_assign_thr (int): The upper bound of the number of GTs for GPU
+ assignment. When the number of gts is above this threshold, the
+ assignment is done on CPU. Negative values mean never assigning on CPU.
+ """
+
+ def __init__(self,
+ pos_iou_thr,
+ neg_iou_thr,
+ min_pos_iou=.0,
+ gt_max_assign_all=True,
+ ignore_iof_thr=-1,
+ ignore_wrt_candidates=True,
+ match_low_quality=True,
+ gpu_assign_thr=-1,
+ iou_calculator=dict(type='BboxOverlaps2D')):
+ self.pos_iou_thr = pos_iou_thr
+ self.neg_iou_thr = neg_iou_thr
+ self.min_pos_iou = min_pos_iou
+ self.gt_max_assign_all = gt_max_assign_all
+ self.ignore_iof_thr = ignore_iof_thr
+ self.ignore_wrt_candidates = ignore_wrt_candidates
+ self.gpu_assign_thr = gpu_assign_thr
+ self.match_low_quality = match_low_quality
+ self.iou_calculator = build_iou_calculator(iou_calculator)
+
+ def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to bboxes.
+
+ This method assigns a gt bbox to every bbox (proposal/anchor); each bbox
+ will be assigned with -1, or a semi-positive number. -1 means negative
+ sample, a semi-positive number is the index (0-based) of the assigned gt.
+ The assignment is done in the following steps, and the order matters.
+
+ 1. assign every bbox to the background
+ 2. assign proposals whose iou with all gts < neg_iou_thr to 0
+ 3. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+ assign it to that gt
+ 4. for each gt bbox, assign its nearest proposals (may be more than
+ one) to itself
+
+ Args:
+ bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+
+ Example:
+ >>> self = MaxIoUAssigner(0.5, 0.5)
+ >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+ >>> gt_bboxes = torch.Tensor([[0, 0, 10, 9]])
+ >>> assign_result = self.assign(bboxes, gt_bboxes)
+ >>> expected_gt_inds = torch.LongTensor([1, 0])
+ >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+ """
+ assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+ gt_bboxes.shape[0] > self.gpu_assign_thr) else False
+ # compute overlap and assign gt on CPU when number of GT is large
+ if assign_on_cpu:
+ device = bboxes.device
+ bboxes = bboxes.cpu()
+ gt_bboxes = gt_bboxes.cpu()
+ if gt_bboxes_ignore is not None:
+ gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+ if gt_labels is not None:
+ gt_labels = gt_labels.cpu()
+
+ overlaps = self.iou_calculator(gt_bboxes, bboxes)
+
+ if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+ and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+ if self.ignore_wrt_candidates:
+ ignore_overlaps = self.iou_calculator(
+ bboxes, gt_bboxes_ignore, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+ else:
+ ignore_overlaps = self.iou_calculator(
+ gt_bboxes_ignore, bboxes, mode='iof')
+ ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+ overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+ assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+ if assign_on_cpu:
+ assign_result.gt_inds = assign_result.gt_inds.to(device)
+ assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+ if assign_result.labels is not None:
+ assign_result.labels = assign_result.labels.to(device)
+ return assign_result
+
+ def assign_wrt_overlaps(self, overlaps, gt_labels=None):
+ """Assign w.r.t. the overlaps of bboxes with gts.
+
+ Args:
+ overlaps (Tensor): Overlaps between k gt_bboxes and n bboxes,
+ shape(k, n).
+ gt_labels (Tensor, optional): Labels of k gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ num_gts, num_bboxes = overlaps.size(0), overlaps.size(1)
+
+ # 1. assign -1 by default
+ assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = overlaps.new_zeros((num_bboxes, ))
+ if num_gts == 0:
+ # No truth, assign everything to background
+ assigned_gt_inds[:] = 0
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = overlaps.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ # for each anchor, which gt best overlaps with it
+ # for each anchor, the max iou of all gts
+ max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+ # for each gt, which anchor best overlaps with it
+ # for each gt, the max iou of all proposals
+ gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
+
+ # 2. assign negative: below
+ # the negative inds are set to be 0
+ if isinstance(self.neg_iou_thr, float):
+ assigned_gt_inds[(max_overlaps >= 0)
+ & (max_overlaps < self.neg_iou_thr)] = 0
+ elif isinstance(self.neg_iou_thr, tuple):
+ assert len(self.neg_iou_thr) == 2
+ assigned_gt_inds[(max_overlaps >= self.neg_iou_thr[0])
+ & (max_overlaps < self.neg_iou_thr[1])] = 0
+
+ # 3. assign positive: above positive IoU threshold
+ pos_inds = max_overlaps >= self.pos_iou_thr
+ assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
+
+ if self.match_low_quality:
+ # Low-quality matching will overwrite the assigned_gt_inds assigned
+ # in Step 3. Thus, the assigned gt might not be the best one for
+ # prediction.
+ # For example, if bbox A has 0.9 and 0.8 iou with GT bbox 1 & 2,
+ # GT bbox 1 will be assigned as the best target for bbox A in step 3.
+ # However, if GT bbox 2's gt_argmax_overlaps = A, bbox A's
+ # assigned_gt_inds will be overwritten to be GT bbox 2.
+ # This might be the reason that it is not used in ROI Heads.
+ for i in range(num_gts):
+ if gt_max_overlaps[i] >= self.min_pos_iou:
+ if self.gt_max_assign_all:
+ max_iou_inds = overlaps[i, :] == gt_max_overlaps[i]
+ assigned_gt_inds[max_iou_inds] = i + 1
+ else:
+ assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
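+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): shows how
+     # step 4 (low-quality matching) rescues a gt whose best IoU is below
+     # pos_iou_thr. Thresholds and the toy overlap matrix are made up.
+     assigner = MaxIoUAssigner(pos_iou_thr=0.7, neg_iou_thr=0.3, min_pos_iou=0.3)
+     # 2 gts x 3 anchors; gt 2 never reaches pos_iou_thr, but its best anchor
+     # (the third one, IoU 0.4) is still assigned to it in step 4.
+     overlaps = torch.tensor([[0.9, 0.1, 0.0],
+                              [0.0, 0.2, 0.4]])
+     result = assigner.assign_wrt_overlaps(overlaps)
+     print(result.gt_inds)  # tensor([1, 0, 2])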
diff --git a/mmdet/core/bbox/assigners/point_assigner.py b/mmdet/core/bbox/assigners/point_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb8f5e4edc63f4851e2067034c5e67a3558f31bc
--- /dev/null
+++ b/mmdet/core/bbox/assigners/point_assigner.py
@@ -0,0 +1,133 @@
+import torch
+
+from ..builder import BBOX_ASSIGNERS
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+@BBOX_ASSIGNERS.register_module()
+class PointAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each point.
+
+ Each proposal will be assigned `0`, or a positive integer
+ indicating the ground truth index.
+
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+ """
+
+ def __init__(self, scale=4, pos_num=3):
+ self.scale = scale
+ self.pos_num = pos_num
+
+ def assign(self, points, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+ """Assign gt to points.
+
+ This method assigns a gt bbox to every point set; each point set
+ will be assigned with `0`, or a positive integer. 0 means negative
+ sample (background), and a positive integer is the index (1-based)
+ of the assigned gt.
+ The assignment is done in the following steps, and the order matters.
+
+ 1. assign every point to 0 (background)
+ 2. A point is assigned to some gt bbox if
+ (i) the point is within the k closest points to the gt bbox
+ (ii) the distance between this point and the gt is smaller than
+ its distance to any other gt bbox
+
+ Args:
+ points (Tensor): points to be assigned, shape(n, 3) while last
+ dimension stands for (x, y, stride).
+ gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ NOTE: currently unused.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ num_points = points.shape[0]
+ num_gts = gt_bboxes.shape[0]
+
+ if num_gts == 0 or num_points == 0:
+ # If no truth assign everything to the background
+ assigned_gt_inds = points.new_full((num_points, ),
+ 0,
+ dtype=torch.long)
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = points.new_full((num_points, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
+
+ points_xy = points[:, :2]
+ points_stride = points[:, 2]
+ points_lvl = torch.log2(
+ points_stride).int() # [3...,4...,5...,6...,7...]
+ lvl_min, lvl_max = points_lvl.min(), points_lvl.max()
+
+ # assign gt box
+ gt_bboxes_xy = (gt_bboxes[:, :2] + gt_bboxes[:, 2:]) / 2
+ gt_bboxes_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
+ scale = self.scale
+ gt_bboxes_lvl = ((torch.log2(gt_bboxes_wh[:, 0] / scale) +
+ torch.log2(gt_bboxes_wh[:, 1] / scale)) / 2).int()
+ gt_bboxes_lvl = torch.clamp(gt_bboxes_lvl, min=lvl_min, max=lvl_max)
+
+ # stores the assigned gt index of each point
+ assigned_gt_inds = points.new_zeros((num_points, ), dtype=torch.long)
+ # stores the assigned gt dist (to this point) of each point
+ assigned_gt_dist = points.new_full((num_points, ), float('inf'))
+ points_range = torch.arange(points.shape[0])
+
+ for idx in range(num_gts):
+ gt_lvl = gt_bboxes_lvl[idx]
+ # get the index of points in this level
+ lvl_idx = gt_lvl == points_lvl
+ points_index = points_range[lvl_idx]
+ # get the points in this level
+ lvl_points = points_xy[lvl_idx, :]
+ # get the center point of gt
+ gt_point = gt_bboxes_xy[[idx], :]
+ # get width and height of gt
+ gt_wh = gt_bboxes_wh[[idx], :]
+ # compute the distance between gt center and
+ # all points in this level
+ points_gt_dist = ((lvl_points - gt_point) / gt_wh).norm(dim=1)
+ # find the nearest k points to gt center in this level
+ min_dist, min_dist_index = torch.topk(
+ points_gt_dist, self.pos_num, largest=False)
+ # the index of nearest k points to gt center in this level
+ min_dist_points_index = points_index[min_dist_index]
+ # The less_than_recorded_index stores the index
+ # of min_dist that is less than the assigned_gt_dist, where
+ # assigned_gt_dist stores the dist from the previously assigned gt
+ # (if any) to each point.
+ less_than_recorded_index = min_dist < assigned_gt_dist[
+ min_dist_points_index]
+ # The min_dist_points_index stores the index of points satisfy:
+ # (1) it is k nearest to current gt center in this level.
+ # (2) it is closer to current gt center than other gt center.
+ min_dist_points_index = min_dist_points_index[
+ less_than_recorded_index]
+ # assign the result
+ assigned_gt_inds[min_dist_points_index] = idx + 1
+ assigned_gt_dist[min_dist_points_index] = min_dist[
+ less_than_recorded_index]
+
+ if gt_labels is not None:
+ assigned_labels = assigned_gt_inds.new_full((num_points, ), -1)
+ pos_inds = torch.nonzero(
+ assigned_gt_inds > 0, as_tuple=False).squeeze()
+ if pos_inds.numel() > 0:
+ assigned_labels[pos_inds] = gt_labels[
+ assigned_gt_inds[pos_inds] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, assigned_gt_inds, None, labels=assigned_labels)
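+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): how a gt
+     # box is mapped to a feature level from its size with the log2 rule used
+     # in `assign`. The boxes below are toy values.
+     gt_bboxes = torch.Tensor([[0, 0, 32, 32], [0, 0, 256, 128]])
+     scale = 4
+     gt_wh = (gt_bboxes[:, 2:] - gt_bboxes[:, :2]).clamp(min=1e-6)
+     gt_lvl = ((torch.log2(gt_wh[:, 0] / scale) +
+                torch.log2(gt_wh[:, 1] / scale)) / 2).int()
+     print(gt_lvl)  # tensor([3, 5], dtype=torch.int32): small box -> low level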
diff --git a/mmdet/core/bbox/assigners/region_assigner.py b/mmdet/core/bbox/assigners/region_assigner.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e8464b97c8d8f44488d7bb781ca2e733a258e55
--- /dev/null
+++ b/mmdet/core/bbox/assigners/region_assigner.py
@@ -0,0 +1,221 @@
+import torch
+
+from mmdet.core import anchor_inside_flags
+from ..builder import BBOX_ASSIGNERS
+from .assign_result import AssignResult
+from .base_assigner import BaseAssigner
+
+
+def calc_region(bbox, ratio, stride, featmap_size=None):
+ """Calculate region of the box defined by the ratio, the ratio is from the
+ center of the box to every edge."""
+ # project bbox on the feature
+ f_bbox = bbox / stride
+ x1 = torch.round((1 - ratio) * f_bbox[0] + ratio * f_bbox[2])
+ y1 = torch.round((1 - ratio) * f_bbox[1] + ratio * f_bbox[3])
+ x2 = torch.round(ratio * f_bbox[0] + (1 - ratio) * f_bbox[2])
+ y2 = torch.round(ratio * f_bbox[1] + (1 - ratio) * f_bbox[3])
+ if featmap_size is not None:
+ x1 = x1.clamp(min=0, max=featmap_size[1])
+ y1 = y1.clamp(min=0, max=featmap_size[0])
+ x2 = x2.clamp(min=0, max=featmap_size[1])
+ y2 = y2.clamp(min=0, max=featmap_size[0])
+ return (x1, y1, x2, y2)
+
+
+def anchor_ctr_inside_region_flags(anchors, stride, region):
+ """Get the flag indicate whether anchor centers are inside regions."""
+ x1, y1, x2, y2 = region
+ f_anchors = anchors / stride
+ x = (f_anchors[:, 0] + f_anchors[:, 2]) * 0.5
+ y = (f_anchors[:, 1] + f_anchors[:, 3]) * 0.5
+ flags = (x >= x1) & (x <= x2) & (y >= y1) & (y <= y2)
+ return flags
+
+
+@BBOX_ASSIGNERS.register_module()
+class RegionAssigner(BaseAssigner):
+ """Assign a corresponding gt bbox or background to each bbox.
+
+ Each proposal will be assigned `-1`, `0`, or a positive integer
+ indicating the ground truth index.
+
+ - -1: don't care
+ - 0: negative sample, no assigned gt
+ - positive integer: positive sample, index (1-based) of assigned gt
+
+ Args:
+ center_ratio (float): Ratio of the region in the center of the bbox
+ used to define positive samples. Defaults to 0.2.
+ ignore_ratio (float): Ratio of the region used to define ignore
+ samples. Defaults to 0.5.
+ """
+
+ def __init__(self, center_ratio=0.2, ignore_ratio=0.5):
+ self.center_ratio = center_ratio
+ self.ignore_ratio = ignore_ratio
+
+ def assign(self,
+ mlvl_anchors,
+ mlvl_valid_flags,
+ gt_bboxes,
+ img_meta,
+ featmap_sizes,
+ anchor_scale,
+ anchor_strides,
+ gt_bboxes_ignore=None,
+ gt_labels=None,
+ allowed_border=0):
+ """Assign gt to anchors.
+
+ This method assigns a gt bbox to every bbox (proposal/anchor); each bbox
+ will be assigned with -1, 0, or a positive number. -1 means don't care,
+ 0 means negative sample, and a positive number is the index (1-based) of
+ the assigned gt.
+ The assignment is done in the following steps, and the order matters.
+
+ 1. Assign every anchor to 0 (negative)
+ For each gt_bboxes:
+ 2. Compute ignore flags based on ignore_region then
+ assign -1 to anchors w.r.t. ignore flags
+ 3. Compute pos flags based on center_region then
+ assign gt_bboxes to anchors w.r.t. pos flags
+ 4. Compute ignore flags based on adjacent anchor lvl then
+ assign -1 to anchors w.r.t. ignore flags
+ 5. Assign anchor outside of image to -1
+
+ Args:
+ mlvl_anchors (list[Tensor]): Multi level anchors.
+ mlvl_valid_flags (list[Tensor]): Multi level valid flags.
+ gt_bboxes (Tensor): Ground truth bboxes of the image, shape (k, 4).
+ img_meta (dict): Meta info of the image.
+ featmap_sizes (list[tuple]): Feature map size of each level.
+ anchor_scale (int): Scale of the anchor.
+ anchor_strides (list[int]): Stride of the anchor.
+ gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+ labelled as `ignored`, e.g., crowd boxes in COCO.
+ gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+ allowed_border (int, optional): The border to allow the valid
+ anchor. Defaults to 0.
+
+ Returns:
+ :obj:`AssignResult`: The assign result.
+ """
+ if gt_bboxes_ignore is not None:
+ raise NotImplementedError
+
+ num_gts = gt_bboxes.shape[0]
+ num_bboxes = sum(x.shape[0] for x in mlvl_anchors)
+
+ if num_gts == 0 or num_bboxes == 0:
+ # No ground truth or boxes, return empty assignment
+ max_overlaps = gt_bboxes.new_zeros((num_bboxes, ))
+ assigned_gt_inds = gt_bboxes.new_zeros((num_bboxes, ),
+ dtype=torch.long)
+ if gt_labels is None:
+ assigned_labels = None
+ else:
+ assigned_labels = gt_bboxes.new_full((num_bboxes, ),
+ -1,
+ dtype=torch.long)
+ return AssignResult(
+ num_gts,
+ assigned_gt_inds,
+ max_overlaps,
+ labels=assigned_labels)
+
+ num_lvls = len(mlvl_anchors)
+ r1 = (1 - self.center_ratio) / 2
+ r2 = (1 - self.ignore_ratio) / 2
+
+ scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
+ min_anchor_size = scale.new_full(
+ (1, ), float(anchor_scale * anchor_strides[0]))
+ target_lvls = torch.floor(
+ torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
+ target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
+
+ # 1. assign 0 (negative) by default
+ mlvl_assigned_gt_inds = []
+ mlvl_ignore_flags = []
+ for lvl in range(num_lvls):
+ h, w = featmap_sizes[lvl]
+ assert h * w == mlvl_anchors[lvl].shape[0]
+ assigned_gt_inds = gt_bboxes.new_full((h * w, ),
+ 0,
+ dtype=torch.long)
+ ignore_flags = torch.zeros_like(assigned_gt_inds)
+ mlvl_assigned_gt_inds.append(assigned_gt_inds)
+ mlvl_ignore_flags.append(ignore_flags)
+
+ for gt_id in range(num_gts):
+ lvl = target_lvls[gt_id].item()
+ featmap_size = featmap_sizes[lvl]
+ stride = anchor_strides[lvl]
+ anchors = mlvl_anchors[lvl]
+ gt_bbox = gt_bboxes[gt_id, :4]
+
+ # Compute regions
+ ignore_region = calc_region(gt_bbox, r2, stride, featmap_size)
+ ctr_region = calc_region(gt_bbox, r1, stride, featmap_size)
+
+ # 2. Assign -1 to ignore flags
+ ignore_flags = anchor_ctr_inside_region_flags(
+ anchors, stride, ignore_region)
+ mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
+
+ # 3. Assign gt_bboxes to pos flags
+ pos_flags = anchor_ctr_inside_region_flags(anchors, stride,
+ ctr_region)
+ mlvl_assigned_gt_inds[lvl][pos_flags] = gt_id + 1
+
+ # 4. Assign -1 to ignore adjacent lvl
+ if lvl > 0:
+ d_lvl = lvl - 1
+ d_anchors = mlvl_anchors[d_lvl]
+ d_featmap_size = featmap_sizes[d_lvl]
+ d_stride = anchor_strides[d_lvl]
+ d_ignore_region = calc_region(gt_bbox, r2, d_stride,
+ d_featmap_size)
+ ignore_flags = anchor_ctr_inside_region_flags(
+ d_anchors, d_stride, d_ignore_region)
+ mlvl_ignore_flags[d_lvl][ignore_flags] = 1
+ if lvl < num_lvls - 1:
+ u_lvl = lvl + 1
+ u_anchors = mlvl_anchors[u_lvl]
+ u_featmap_size = featmap_sizes[u_lvl]
+ u_stride = anchor_strides[u_lvl]
+ u_ignore_region = calc_region(gt_bbox, r2, u_stride,
+ u_featmap_size)
+ ignore_flags = anchor_ctr_inside_region_flags(
+ u_anchors, u_stride, u_ignore_region)
+ mlvl_ignore_flags[u_lvl][ignore_flags] = 1
+
+ # 4. (cont.) Assign -1 to ignore adjacent lvl
+ for lvl in range(num_lvls):
+ ignore_flags = mlvl_ignore_flags[lvl]
+ mlvl_assigned_gt_inds[lvl][ignore_flags] = -1
+
+ # 5. Assign -1 to anchor outside of image
+ flat_assigned_gt_inds = torch.cat(mlvl_assigned_gt_inds)
+ flat_anchors = torch.cat(mlvl_anchors)
+ flat_valid_flags = torch.cat(mlvl_valid_flags)
+ assert (flat_assigned_gt_inds.shape[0] == flat_anchors.shape[0] ==
+ flat_valid_flags.shape[0])
+ inside_flags = anchor_inside_flags(flat_anchors, flat_valid_flags,
+ img_meta['img_shape'],
+ allowed_border)
+ outside_flags = ~inside_flags
+ flat_assigned_gt_inds[outside_flags] = -1
+
+ if gt_labels is not None:
+ assigned_labels = torch.zeros_like(flat_assigned_gt_inds)
+ pos_flags = flat_assigned_gt_inds > 0
+ assigned_labels[pos_flags] = gt_labels[
+ flat_assigned_gt_inds[pos_flags] - 1]
+ else:
+ assigned_labels = None
+
+ return AssignResult(
+ num_gts, flat_assigned_gt_inds, None, labels=assigned_labels)
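+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): the center
+     # and ignore regions of a toy gt box on a stride-8 feature map, using the
+     # default ratios from __init__ (center_ratio=0.2, ignore_ratio=0.5).
+     gt_bbox = torch.Tensor([16., 16., 80., 80.])
+     r1 = (1 - 0.2) / 2  # positive (center) region
+     r2 = (1 - 0.5) / 2  # ignore region
+     print(calc_region(gt_bbox, r1, stride=8))  # central 20% of the box
+     print(calc_region(gt_bbox, r2, stride=8))  # wider central 50% band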
diff --git a/mmdet/core/bbox/builder.py b/mmdet/core/bbox/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..682683b62ae55396f24e9f9eea0f8193e2e88de6
--- /dev/null
+++ b/mmdet/core/bbox/builder.py
@@ -0,0 +1,20 @@
+from mmcv.utils import Registry, build_from_cfg
+
+BBOX_ASSIGNERS = Registry('bbox_assigner')
+BBOX_SAMPLERS = Registry('bbox_sampler')
+BBOX_CODERS = Registry('bbox_coder')
+
+
+def build_assigner(cfg, **default_args):
+ """Builder of box assigner."""
+ return build_from_cfg(cfg, BBOX_ASSIGNERS, default_args)
+
+
+def build_sampler(cfg, **default_args):
+ """Builder of box sampler."""
+ return build_from_cfg(cfg, BBOX_SAMPLERS, default_args)
+
+
+def build_bbox_coder(cfg, **default_args):
+ """Builder of box coder."""
+ return build_from_cfg(cfg, BBOX_CODERS, default_args)
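+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): shows how
+     # a registry turns a config dict into an object. ToyAssigner is a dummy
+     # class made up here purely for the demonstration.
+     @BBOX_ASSIGNERS.register_module()
+     class ToyAssigner:
+
+         def __init__(self, pos_iou_thr):
+             self.pos_iou_thr = pos_iou_thr
+
+     assigner = build_assigner(dict(type='ToyAssigner', pos_iou_thr=0.5))
+     print(type(assigner).__name__, assigner.pos_iou_thr)  # ToyAssigner 0.5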
diff --git a/mmdet/core/bbox/coder/__init__.py b/mmdet/core/bbox/coder/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae455ba8fc0e0727e2d581cdc8f20fceededf99a
--- /dev/null
+++ b/mmdet/core/bbox/coder/__init__.py
@@ -0,0 +1,13 @@
+from .base_bbox_coder import BaseBBoxCoder
+from .bucketing_bbox_coder import BucketingBBoxCoder
+from .delta_xywh_bbox_coder import DeltaXYWHBBoxCoder
+from .legacy_delta_xywh_bbox_coder import LegacyDeltaXYWHBBoxCoder
+from .pseudo_bbox_coder import PseudoBBoxCoder
+from .tblr_bbox_coder import TBLRBBoxCoder
+from .yolo_bbox_coder import YOLOBBoxCoder
+
+__all__ = [
+ 'BaseBBoxCoder', 'PseudoBBoxCoder', 'DeltaXYWHBBoxCoder',
+ 'LegacyDeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'YOLOBBoxCoder',
+ 'BucketingBBoxCoder'
+]
diff --git a/mmdet/core/bbox/coder/base_bbox_coder.py b/mmdet/core/bbox/coder/base_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf0b34c7cc2fe561718b0c884990beb40a993643
--- /dev/null
+++ b/mmdet/core/bbox/coder/base_bbox_coder.py
@@ -0,0 +1,17 @@
+from abc import ABCMeta, abstractmethod
+
+
+class BaseBBoxCoder(metaclass=ABCMeta):
+ """Base bounding box coder."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ @abstractmethod
+ def encode(self, bboxes, gt_bboxes):
+ """Encode deltas between bboxes and ground truth boxes."""
+
+ @abstractmethod
+ def decode(self, bboxes, bboxes_pred):
+ """Decode the predicted bboxes according to prediction and base
+ boxes."""
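+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): a trivial
+     # identity coder showing the interface a concrete subclass must provide.
+     class IdentityBBoxCoder(BaseBBoxCoder):
+
+         def encode(self, bboxes, gt_bboxes):
+             # the "deltas" are simply the raw gt boxes
+             return gt_bboxes
+
+         def decode(self, bboxes, bboxes_pred):
+             # predictions are already absolute boxes
+             return bboxes_pred
+
+     coder = IdentityBBoxCoder()
+     print(coder.encode([[0, 0, 1, 1]], [[0, 0, 2, 2]]))  # [[0, 0, 2, 2]]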
diff --git a/mmdet/core/bbox/coder/bucketing_bbox_coder.py b/mmdet/core/bbox/coder/bucketing_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..92d24b4519edece7a4af8f5cfa9af025b25f2dad
--- /dev/null
+++ b/mmdet/core/bbox/coder/bucketing_bbox_coder.py
@@ -0,0 +1,350 @@
+import mmcv
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from ..builder import BBOX_CODERS
+from ..transforms import bbox_rescale
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class BucketingBBoxCoder(BaseBBoxCoder):
+ """Bucketing BBox Coder for Side-Aware Boundary Localization (SABL).
+
+ Boundary Localization with Bucketing and Bucketing Guided Rescoring
+ are implemented here.
+
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+ Args:
+ num_buckets (int): Number of buckets.
+ scale_factor (int): Scale factor of proposals to generate buckets.
+ offset_topk (int): Topk buckets are used to generate
+ bucket fine regression targets. Defaults to 2.
+ offset_upperbound (float): Offset upperbound to generate
+ bucket fine regression targets.
+ To avoid too large offset displacements. Defaults to 1.0.
+ cls_ignore_neighbor (bool): Whether to ignore the second nearest bucket.
+ Defaults to True.
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+ """
+
+ def __init__(self,
+ num_buckets,
+ scale_factor,
+ offset_topk=2,
+ offset_upperbound=1.0,
+ cls_ignore_neighbor=True,
+ clip_border=True):
+ super(BucketingBBoxCoder, self).__init__()
+ self.num_buckets = num_buckets
+ self.scale_factor = scale_factor
+ self.offset_topk = offset_topk
+ self.offset_upperbound = offset_upperbound
+ self.cls_ignore_neighbor = cls_ignore_neighbor
+ self.clip_border = clip_border
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get bucketing estimation and fine regression targets during
+ training.
+
+ Args:
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
+ ground truth boxes.
+
+ Returns:
+ encoded_bboxes (tuple[Tensor]): Bucketing estimation and
+ fine regression targets and weights.
+ """
+
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = bbox2bucket(bboxes, gt_bboxes, self.num_buckets,
+ self.scale_factor, self.offset_topk,
+ self.offset_upperbound,
+ self.cls_ignore_neighbor)
+ return encoded_bboxes
+
+ def decode(self, bboxes, pred_bboxes, max_shape=None):
+ """Apply transformation `pred_bboxes` to `boxes`.
+ Args:
+ bboxes (torch.Tensor): Basic boxes.
+ pred_bboxes (torch.Tensor): Predictions for bucketing estimation
+ and fine regression.
+ max_shape (tuple[int], optional): Maximum shape of boxes.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ assert len(pred_bboxes) == 2
+ cls_preds, offset_preds = pred_bboxes
+ assert cls_preds.size(0) == bboxes.size(0) and offset_preds.size(
+ 0) == bboxes.size(0)
+ decoded_bboxes = bucket2bbox(bboxes, cls_preds, offset_preds,
+ self.num_buckets, self.scale_factor,
+ max_shape, self.clip_border)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def generat_buckets(proposals, num_buckets, scale_factor=1.0):
+ """Generate buckets w.r.t bucket number and scale factor of proposals.
+
+ Args:
+ proposals (Tensor): Shape (n, 4)
+ num_buckets (int): Number of buckets.
+ scale_factor (float): Scale factor to rescale proposals.
+
+ Returns:
+ tuple[Tensor]: (bucket_w, bucket_h, l_buckets, r_buckets,
+ t_buckets, d_buckets)
+
+ - bucket_w: Width of buckets on x-axis. Shape (n, ).
+ - bucket_h: Height of buckets on y-axis. Shape (n, ).
+ - l_buckets: Left buckets. Shape (n, ceil(side_num/2)).
+ - r_buckets: Right buckets. Shape (n, ceil(side_num/2)).
+ - t_buckets: Top buckets. Shape (n, ceil(side_num/2)).
+ - d_buckets: Down buckets. Shape (n, ceil(side_num/2)).
+ """
+ proposals = bbox_rescale(proposals, scale_factor)
+
+ # number of buckets in each side
+ side_num = int(np.ceil(num_buckets / 2.0))
+ pw = proposals[..., 2] - proposals[..., 0]
+ ph = proposals[..., 3] - proposals[..., 1]
+ px1 = proposals[..., 0]
+ py1 = proposals[..., 1]
+ px2 = proposals[..., 2]
+ py2 = proposals[..., 3]
+
+ bucket_w = pw / num_buckets
+ bucket_h = ph / num_buckets
+
+ # left buckets
+ l_buckets = px1[:, None] + (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
+ # right buckets
+ r_buckets = px2[:, None] - (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_w[:, None]
+ # top buckets
+ t_buckets = py1[:, None] + (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
+ # down buckets
+ d_buckets = py2[:, None] - (0.5 + torch.arange(
+ 0, side_num).to(proposals).float())[None, :] * bucket_h[:, None]
+ return bucket_w, bucket_h, l_buckets, r_buckets, t_buckets, d_buckets
+
+
+@mmcv.jit(coderize=True)
+def bbox2bucket(proposals,
+ gt,
+ num_buckets,
+ scale_factor,
+ offset_topk=2,
+ offset_upperbound=1.0,
+ cls_ignore_neighbor=True):
+ """Generate buckets estimation and fine regression targets.
+
+ Args:
+ proposals (Tensor): Shape (n, 4)
+ gt (Tensor): Shape (n, 4)
+ num_buckets (int): Number of buckets.
+ scale_factor (float): Scale factor to rescale proposals.
+ offset_topk (int): Topk buckets are used to generate
+ bucket fine regression targets. Defaults to 2.
+ offset_upperbound (float): Offset allowance to generate
+ bucket fine regression targets.
+ To avoid too large offset displacements. Defaults to 1.0.
+ cls_ignore_neighbor (bool): Whether to ignore the second nearest bucket.
+ Defaults to True.
+
+ Returns:
+ tuple[Tensor]: (offsets, offsets_weights, bucket_labels, cls_weights).
+
+ - offsets: Fine regression targets. \
+ Shape (n, num_buckets*2).
+ - offsets_weights: Fine regression weights. \
+ Shape (n, num_buckets*2).
+ - bucket_labels: Bucketing estimation labels. \
+ Shape (n, num_buckets*2).
+ - cls_weights: Bucketing estimation weights. \
+ Shape (n, num_buckets*2).
+ """
+ assert proposals.size() == gt.size()
+
+ # generate buckets
+ proposals = proposals.float()
+ gt = gt.float()
+ (bucket_w, bucket_h, l_buckets, r_buckets, t_buckets,
+ d_buckets) = generat_buckets(proposals, num_buckets, scale_factor)
+
+ gx1 = gt[..., 0]
+ gy1 = gt[..., 1]
+ gx2 = gt[..., 2]
+ gy2 = gt[..., 3]
+
+ # generate offset targets and weights
+ # offsets from buckets to gts
+ l_offsets = (l_buckets - gx1[:, None]) / bucket_w[:, None]
+ r_offsets = (r_buckets - gx2[:, None]) / bucket_w[:, None]
+ t_offsets = (t_buckets - gy1[:, None]) / bucket_h[:, None]
+ d_offsets = (d_buckets - gy2[:, None]) / bucket_h[:, None]
+
+ # select top-k nearest buckets
+ l_topk, l_label = l_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+ r_topk, r_label = r_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+ t_topk, t_label = t_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+ d_topk, d_label = d_offsets.abs().topk(
+ offset_topk, dim=1, largest=False, sorted=True)
+
+ offset_l_weights = l_offsets.new_zeros(l_offsets.size())
+ offset_r_weights = r_offsets.new_zeros(r_offsets.size())
+ offset_t_weights = t_offsets.new_zeros(t_offsets.size())
+ offset_d_weights = d_offsets.new_zeros(d_offsets.size())
+ inds = torch.arange(0, proposals.size(0)).to(proposals).long()
+
+ # generate offset weights of top-k nearest buckets
+ for k in range(offset_topk):
+ if k >= 1:
+ offset_l_weights[inds, l_label[:,
+ k]] = (l_topk[:, k] <
+ offset_upperbound).float()
+ offset_r_weights[inds, r_label[:,
+ k]] = (r_topk[:, k] <
+ offset_upperbound).float()
+ offset_t_weights[inds, t_label[:,
+ k]] = (t_topk[:, k] <
+ offset_upperbound).float()
+ offset_d_weights[inds, d_label[:,
+ k]] = (d_topk[:, k] <
+ offset_upperbound).float()
+ else:
+ offset_l_weights[inds, l_label[:, k]] = 1.0
+ offset_r_weights[inds, r_label[:, k]] = 1.0
+ offset_t_weights[inds, t_label[:, k]] = 1.0
+ offset_d_weights[inds, d_label[:, k]] = 1.0
+
+ offsets = torch.cat([l_offsets, r_offsets, t_offsets, d_offsets], dim=-1)
+ offsets_weights = torch.cat([
+ offset_l_weights, offset_r_weights, offset_t_weights, offset_d_weights
+ ],
+ dim=-1)
+
+ # generate bucket labels and weight
+ side_num = int(np.ceil(num_buckets / 2.0))
+ labels = torch.stack(
+ [l_label[:, 0], r_label[:, 0], t_label[:, 0], d_label[:, 0]], dim=-1)
+
+ batch_size = labels.size(0)
+ bucket_labels = F.one_hot(labels.view(-1), side_num).view(batch_size,
+ -1).float()
+ bucket_cls_l_weights = (l_offsets.abs() < 1).float()
+ bucket_cls_r_weights = (r_offsets.abs() < 1).float()
+ bucket_cls_t_weights = (t_offsets.abs() < 1).float()
+ bucket_cls_d_weights = (d_offsets.abs() < 1).float()
+ bucket_cls_weights = torch.cat([
+ bucket_cls_l_weights, bucket_cls_r_weights, bucket_cls_t_weights,
+ bucket_cls_d_weights
+ ],
+ dim=-1)
+ # ignore second nearest buckets for cls if necessary
+ if cls_ignore_neighbor:
+ bucket_cls_weights = (~((bucket_cls_weights == 1) &
+ (bucket_labels == 0))).float()
+ else:
+ bucket_cls_weights[:] = 1.0
+ return offsets, offsets_weights, bucket_labels, bucket_cls_weights
+
+
+@mmcv.jit(coderize=True)
+def bucket2bbox(proposals,
+ cls_preds,
+ offset_preds,
+ num_buckets,
+ scale_factor=1.0,
+ max_shape=None,
+ clip_border=True):
+ """Apply bucketing estimation (cls preds) and fine regression (offset
+ preds) to generate det bboxes.
+
+ Args:
+ proposals (Tensor): Boxes to be transformed. Shape (n, 4)
+ cls_preds (Tensor): bucketing estimation. Shape (n, num_buckets*2).
+ offset_preds (Tensor): fine regression. Shape (n, num_buckets*2).
+ num_buckets (int): Number of buckets.
+ scale_factor (float): Scale factor to rescale proposals.
+ max_shape (tuple[int, int]): Maximum bounds for boxes, specifies (H, W).
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+
+ Returns:
+ tuple[Tensor]: (bboxes, loc_confidence).
+
+ - bboxes: predicted bboxes. Shape (n, 4)
+ - loc_confidence: localization confidence of predicted bboxes.
+ Shape (n,).
+ """
+
+ side_num = int(np.ceil(num_buckets / 2.0))
+ cls_preds = cls_preds.view(-1, side_num)
+ offset_preds = offset_preds.view(-1, side_num)
+
+ scores = F.softmax(cls_preds, dim=1)
+ score_topk, score_label = scores.topk(2, dim=1, largest=True, sorted=True)
+
+ rescaled_proposals = bbox_rescale(proposals, scale_factor)
+
+ pw = rescaled_proposals[..., 2] - rescaled_proposals[..., 0]
+ ph = rescaled_proposals[..., 3] - rescaled_proposals[..., 1]
+ px1 = rescaled_proposals[..., 0]
+ py1 = rescaled_proposals[..., 1]
+ px2 = rescaled_proposals[..., 2]
+ py2 = rescaled_proposals[..., 3]
+
+ bucket_w = pw / num_buckets
+ bucket_h = ph / num_buckets
+
+ score_inds_l = score_label[0::4, 0]
+ score_inds_r = score_label[1::4, 0]
+ score_inds_t = score_label[2::4, 0]
+ score_inds_d = score_label[3::4, 0]
+ l_buckets = px1 + (0.5 + score_inds_l.float()) * bucket_w
+ r_buckets = px2 - (0.5 + score_inds_r.float()) * bucket_w
+ t_buckets = py1 + (0.5 + score_inds_t.float()) * bucket_h
+ d_buckets = py2 - (0.5 + score_inds_d.float()) * bucket_h
+
+ offsets = offset_preds.view(-1, 4, side_num)
+ inds = torch.arange(proposals.size(0)).to(proposals).long()
+ l_offsets = offsets[:, 0, :][inds, score_inds_l]
+ r_offsets = offsets[:, 1, :][inds, score_inds_r]
+ t_offsets = offsets[:, 2, :][inds, score_inds_t]
+ d_offsets = offsets[:, 3, :][inds, score_inds_d]
+
+ x1 = l_buckets - l_offsets * bucket_w
+ x2 = r_buckets - r_offsets * bucket_w
+ y1 = t_buckets - t_offsets * bucket_h
+ y2 = d_buckets - d_offsets * bucket_h
+
+ if clip_border and max_shape is not None:
+ x1 = x1.clamp(min=0, max=max_shape[1] - 1)
+ y1 = y1.clamp(min=0, max=max_shape[0] - 1)
+ x2 = x2.clamp(min=0, max=max_shape[1] - 1)
+ y2 = y2.clamp(min=0, max=max_shape[0] - 1)
+ bboxes = torch.cat([x1[:, None], y1[:, None], x2[:, None], y2[:, None]],
+ dim=-1)
+
+ # bucketing guided rescoring
+ loc_confidence = score_topk[:, 0]
+ top2_neighbor_inds = (score_label[:, 0] - score_label[:, 1]).abs() == 1
+ loc_confidence += score_topk[:, 1] * top2_neighbor_inds.float()
+ loc_confidence = loc_confidence.view(-1, 4).mean(dim=1)
+
+ return bboxes, loc_confidence
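+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): bucket
+     # widths and left-side bucket centers for a single toy proposal with
+     # 8 buckets and no rescaling.
+     proposal = torch.Tensor([[0., 0., 80., 40.]])
+     bucket_w, bucket_h, l_b, r_b, t_b, d_b = generat_buckets(
+         proposal, num_buckets=8, scale_factor=1.0)
+     print(bucket_w, bucket_h)  # tensor([10.]) tensor([5.])
+     print(l_b)  # left bucket centers: tensor([[ 5., 15., 25., 35.]])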
diff --git a/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..da317184a6eb6f87b0b658e9ff8be289794a0cb2
--- /dev/null
+++ b/mmdet/core/bbox/coder/delta_xywh_bbox_coder.py
@@ -0,0 +1,237 @@
+import mmcv
+import numpy as np
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class DeltaXYWHBBoxCoder(BaseBBoxCoder):
+ """Delta XYWH BBox coder.
+
+ Following the practice in `R-CNN <https://arxiv.org/abs/1311.2524>`_,
+ this coder encodes bbox (x1, y1, x2, y2) into delta (dx, dy, dw, dh) and
+ decodes delta (dx, dy, dw, dh) back to original bbox (x1, y1, x2, y2).
+
+ Args:
+ target_means (Sequence[float]): Denormalizing means of target for
+ delta coordinates
+ target_stds (Sequence[float]): Denormalizing standard deviation of
+ target for delta coordinates
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+ """
+
+ def __init__(self,
+ target_means=(0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1.),
+ clip_border=True):
+ super(BaseBBoxCoder, self).__init__()
+ self.means = target_means
+ self.stds = target_stds
+ self.clip_border = clip_border
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes``.
+
+ Args:
+ bboxes (torch.Tensor): Source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
+ ground-truth boxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = bbox2delta(bboxes, gt_bboxes, self.means, self.stds)
+ return encoded_bboxes
+
+ def decode(self,
+ bboxes,
+ pred_bboxes,
+ max_shape=None,
+ wh_ratio_clip=16 / 1000):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+ bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
+ pred_bboxes (Tensor): Encoded offsets with respect to each roi.
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+ when rois is a grid of anchors. Offset encoding follows [1]_.
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+ wh_ratio_clip (float, optional): The allowed ratio between
+ width and height.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+
+ assert pred_bboxes.size(0) == bboxes.size(0)
+ if pred_bboxes.ndim == 3:
+ assert pred_bboxes.size(1) == bboxes.size(1)
+ decoded_bboxes = delta2bbox(bboxes, pred_bboxes, self.means, self.stds,
+ max_shape, wh_ratio_clip, self.clip_border)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def bbox2delta(proposals, gt, means=(0., 0., 0., 0.), stds=(1., 1., 1., 1.)):
+ """Compute deltas of proposals w.r.t. gt.
+
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
+ truth bboxes to get regression target.
+ This is the inverse function of :func:`delta2bbox`.
+
+ Args:
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
+ gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+
+ Returns:
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
+ dw, dh.
+ """
+ assert proposals.size() == gt.size()
+
+ proposals = proposals.float()
+ gt = gt.float()
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
+ pw = proposals[..., 2] - proposals[..., 0]
+ ph = proposals[..., 3] - proposals[..., 1]
+
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
+ gw = gt[..., 2] - gt[..., 0]
+ gh = gt[..., 3] - gt[..., 1]
+
+ dx = (gx - px) / pw
+ dy = (gy - py) / ph
+ dw = torch.log(gw / pw)
+ dh = torch.log(gh / ph)
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
+
+ means = deltas.new_tensor(means).unsqueeze(0)
+ stds = deltas.new_tensor(stds).unsqueeze(0)
+ deltas = deltas.sub_(means).div_(stds)
+
+ return deltas
+
+
+@mmcv.jit(coderize=True)
+def delta2bbox(rois,
+ deltas,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.),
+ max_shape=None,
+ wh_ratio_clip=16 / 1000,
+ clip_border=True):
+ """Apply deltas to shift/scale base boxes.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
+ network outputs used to shift/scale those boxes.
+ This is the inverse function of :func:`bbox2delta`.
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4) or (B, N, 4)
+ deltas (Tensor): Encoded offsets with respect to each roi.
+ Has shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4). Note N = num_anchors * W * H
+ when rois is a grid of anchors. Offset encoding follows [1]_.
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If rois shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+ wh_ratio_clip (float): Maximum aspect ratio for boxes.
+ clip_border (bool, optional): Whether clip the objects outside the
+ border of the image. Defaults to True.
+
+ Returns:
+ Tensor: Boxes with shape (B, N, num_classes * 4) or (B, N, 4) or
+ (N, num_classes * 4) or (N, 4), where 4 represent
+ tl_x, tl_y, br_x, br_y.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Example:
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 5., 5., 5., 5.]])
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
+ >>> [ 1., 1., 1., 1.],
+ >>> [ 0., 0., 2., -1.],
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
+ >>> delta2bbox(rois, deltas, max_shape=(32, 32, 3))
+ tensor([[0.0000, 0.0000, 1.0000, 1.0000],
+ [0.1409, 0.1409, 2.8591, 2.8591],
+ [0.0000, 0.3161, 4.1945, 0.6839],
+ [5.0000, 5.0000, 5.0000, 5.0000]])
+ """
+ means = deltas.new_tensor(means).view(1,
+ -1).repeat(1,
+ deltas.size(-1) // 4)
+ stds = deltas.new_tensor(stds).view(1, -1).repeat(1, deltas.size(-1) // 4)
+ denorm_deltas = deltas * stds + means
+ dx = denorm_deltas[..., 0::4]
+ dy = denorm_deltas[..., 1::4]
+ dw = denorm_deltas[..., 2::4]
+ dh = denorm_deltas[..., 3::4]
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
+ x1, y1 = rois[..., 0], rois[..., 1]
+ x2, y2 = rois[..., 2], rois[..., 3]
+ # Compute center of each roi
+ px = ((x1 + x2) * 0.5).unsqueeze(-1).expand_as(dx)
+ py = ((y1 + y2) * 0.5).unsqueeze(-1).expand_as(dy)
+ # Compute width/height of each roi
+ pw = (x2 - x1).unsqueeze(-1).expand_as(dw)
+ ph = (y2 - y1).unsqueeze(-1).expand_as(dh)
+ # Use exp(network energy) to enlarge/shrink each roi
+ gw = pw * dw.exp()
+ gh = ph * dh.exp()
+ # Use network energy to shift the center of each roi
+ gx = px + pw * dx
+ gy = py + ph * dy
+ # Convert center-xy/width/height to top-left, bottom-right
+ x1 = gx - gw * 0.5
+ y1 = gy - gh * 0.5
+ x2 = gx + gw * 0.5
+ y2 = gy + gh * 0.5
+
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view(deltas.size())
+
+ if clip_border and max_shape is not None:
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = x1.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(x1)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = x1.new_tensor(0)
+ max_xy = torch.cat(
+ [max_shape] * (deltas.size(-1) // 2),
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
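+
+
+ if __name__ == '__main__':
+     # Minimal sketch (added for illustration, not part of mmdet): encoding a
+     # gt box against a proposal and decoding the delta recovers the gt box
+     # (up to floating point error). The boxes are toy values.
+     proposals = torch.Tensor([[0., 0., 10., 10.]])
+     gts = torch.Tensor([[1., 1., 9., 11.]])
+     deltas = bbox2delta(proposals, gts)
+     print(deltas)                          # (dx, dy, dw, dh) for the pair
+     print(delta2bbox(proposals, deltas))   # ~tensor([[1., 1., 9., 11.]])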
diff --git a/mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py b/mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..190309fd42a1b76c12c82fc1acf0511494be5ac3
--- /dev/null
+++ b/mmdet/core/bbox/coder/legacy_delta_xywh_bbox_coder.py
@@ -0,0 +1,215 @@
+import mmcv
+import numpy as np
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class LegacyDeltaXYWHBBoxCoder(BaseBBoxCoder):
+ """Legacy Delta XYWH BBox coder used in MMDet V1.x.
+
+ Following the practice in R-CNN [1]_, this coder encodes bbox (x1, y1, x2,
+ y2) into delta (dx, dy, dw, dh) and decodes delta (dx, dy, dw, dh)
+ back to original bbox (x1, y1, x2, y2).
+
+ Note:
+        The main difference between :class:`LegacyDeltaXYWHBBoxCoder` and
+        :class:`DeltaXYWHBBoxCoder` is whether ``+ 1`` is used during width and
+        height calculation. We suggest using this coder only when testing with
+        MMDet V1.x models.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Args:
+ target_means (Sequence[float]): denormalizing means of target for
+ delta coordinates
+ target_stds (Sequence[float]): denormalizing standard deviation of
+ target for delta coordinates
+ """
+
+ def __init__(self,
+ target_means=(0., 0., 0., 0.),
+ target_stds=(1., 1., 1., 1.)):
+ super(BaseBBoxCoder, self).__init__()
+ self.means = target_means
+ self.stds = target_stds
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes``.
+
+ Args:
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
+ ground-truth boxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = legacy_bbox2delta(bboxes, gt_bboxes, self.means,
+ self.stds)
+ return encoded_bboxes
+
+ def decode(self,
+ bboxes,
+ pred_bboxes,
+ max_shape=None,
+ wh_ratio_clip=16 / 1000):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+            bboxes (torch.Tensor): Basic boxes.
+            pred_bboxes (torch.Tensor): Encoded boxes with shape
+                (N, 4) or (N, 4 * num_classes).
+            max_shape (tuple[int], optional): Maximum shape of boxes.
+                Defaults to None.
+            wh_ratio_clip (float, optional): The allowed ratio between
+                width and height.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ assert pred_bboxes.size(0) == bboxes.size(0)
+ decoded_bboxes = legacy_delta2bbox(bboxes, pred_bboxes, self.means,
+ self.stds, max_shape, wh_ratio_clip)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def legacy_bbox2delta(proposals,
+ gt,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.)):
+ """Compute deltas of proposals w.r.t. gt in the MMDet V1.x manner.
+
+ We usually compute the deltas of x, y, w, h of proposals w.r.t ground
+ truth bboxes to get regression target.
+    This is the inverse function of `legacy_delta2bbox()`.
+
+ Args:
+ proposals (Tensor): Boxes to be transformed, shape (N, ..., 4)
+ gt (Tensor): Gt bboxes to be used as base, shape (N, ..., 4)
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+
+ Returns:
+ Tensor: deltas with shape (N, 4), where columns represent dx, dy,
+ dw, dh.
+ """
+ assert proposals.size() == gt.size()
+
+ proposals = proposals.float()
+ gt = gt.float()
+ px = (proposals[..., 0] + proposals[..., 2]) * 0.5
+ py = (proposals[..., 1] + proposals[..., 3]) * 0.5
+ pw = proposals[..., 2] - proposals[..., 0] + 1.0
+ ph = proposals[..., 3] - proposals[..., 1] + 1.0
+
+ gx = (gt[..., 0] + gt[..., 2]) * 0.5
+ gy = (gt[..., 1] + gt[..., 3]) * 0.5
+ gw = gt[..., 2] - gt[..., 0] + 1.0
+ gh = gt[..., 3] - gt[..., 1] + 1.0
+
+ dx = (gx - px) / pw
+ dy = (gy - py) / ph
+ dw = torch.log(gw / pw)
+ dh = torch.log(gh / ph)
+ deltas = torch.stack([dx, dy, dw, dh], dim=-1)
+
+ means = deltas.new_tensor(means).unsqueeze(0)
+ stds = deltas.new_tensor(stds).unsqueeze(0)
+ deltas = deltas.sub_(means).div_(stds)
+
+ return deltas
+
+
+@mmcv.jit(coderize=True)
+def legacy_delta2bbox(rois,
+ deltas,
+ means=(0., 0., 0., 0.),
+ stds=(1., 1., 1., 1.),
+ max_shape=None,
+ wh_ratio_clip=16 / 1000):
+ """Apply deltas to shift/scale base boxes in the MMDet V1.x manner.
+
+ Typically the rois are anchor or proposed bounding boxes and the deltas are
+ network outputs used to shift/scale those boxes.
+    This is the inverse function of `legacy_bbox2delta()`.
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (N, 4)
+ deltas (Tensor): Encoded offsets with respect to each roi.
+ Has shape (N, 4 * num_classes). Note N = num_anchors * W * H when
+ rois is a grid of anchors. Offset encoding follows [1]_.
+ means (Sequence[float]): Denormalizing means for delta coordinates
+ stds (Sequence[float]): Denormalizing standard deviation for delta
+ coordinates
+        max_shape (tuple[int, int]): Maximum bounds for boxes, specifies (H, W).
+ wh_ratio_clip (float): Maximum aspect ratio for boxes.
+
+ Returns:
+ Tensor: Boxes with shape (N, 4), where columns represent
+ tl_x, tl_y, br_x, br_y.
+
+ References:
+ .. [1] https://arxiv.org/abs/1311.2524
+
+ Example:
+ >>> rois = torch.Tensor([[ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 0., 0., 1., 1.],
+ >>> [ 5., 5., 5., 5.]])
+ >>> deltas = torch.Tensor([[ 0., 0., 0., 0.],
+ >>> [ 1., 1., 1., 1.],
+ >>> [ 0., 0., 2., -1.],
+ >>> [ 0.7, -1.9, -0.5, 0.3]])
+ >>> legacy_delta2bbox(rois, deltas, max_shape=(32, 32))
+ tensor([[0.0000, 0.0000, 1.5000, 1.5000],
+ [0.0000, 0.0000, 5.2183, 5.2183],
+ [0.0000, 0.1321, 7.8891, 0.8679],
+ [5.3967, 2.4251, 6.0033, 3.7749]])
+ """
+ means = deltas.new_tensor(means).repeat(1, deltas.size(1) // 4)
+ stds = deltas.new_tensor(stds).repeat(1, deltas.size(1) // 4)
+ denorm_deltas = deltas * stds + means
+ dx = denorm_deltas[:, 0::4]
+ dy = denorm_deltas[:, 1::4]
+ dw = denorm_deltas[:, 2::4]
+ dh = denorm_deltas[:, 3::4]
+ max_ratio = np.abs(np.log(wh_ratio_clip))
+ dw = dw.clamp(min=-max_ratio, max=max_ratio)
+ dh = dh.clamp(min=-max_ratio, max=max_ratio)
+ # Compute center of each roi
+ px = ((rois[:, 0] + rois[:, 2]) * 0.5).unsqueeze(1).expand_as(dx)
+ py = ((rois[:, 1] + rois[:, 3]) * 0.5).unsqueeze(1).expand_as(dy)
+ # Compute width/height of each roi
+ pw = (rois[:, 2] - rois[:, 0] + 1.0).unsqueeze(1).expand_as(dw)
+ ph = (rois[:, 3] - rois[:, 1] + 1.0).unsqueeze(1).expand_as(dh)
+ # Use exp(network energy) to enlarge/shrink each roi
+ gw = pw * dw.exp()
+ gh = ph * dh.exp()
+ # Use network energy to shift the center of each roi
+ gx = px + pw * dx
+ gy = py + ph * dy
+ # Convert center-xy/width/height to top-left, bottom-right
+
+ # The true legacy box coder should +- 0.5 here.
+ # However, current implementation improves the performance when testing
+ # the models trained in MMDetection 1.X (~0.5 bbox AP, 0.2 mask AP)
+ x1 = gx - gw * 0.5
+ y1 = gy - gh * 0.5
+ x2 = gx + gw * 0.5
+ y2 = gy + gh * 0.5
+ if max_shape is not None:
+ x1 = x1.clamp(min=0, max=max_shape[1] - 1)
+ y1 = y1.clamp(min=0, max=max_shape[0] - 1)
+ x2 = x2.clamp(min=0, max=max_shape[1] - 1)
+ y2 = y2.clamp(min=0, max=max_shape[0] - 1)
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1).view_as(deltas)
+ return bboxes
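
To make the `+ 1` note above concrete, a small hedged sketch using only the two helpers in this file. A (0, 0, 9, 9) proposal is treated as a 10x10 box by the legacy coder, and the decoded box comes back offset by 0.5 because, as the comment above explains, the decoder intentionally drops the true legacy +-0.5 correction:

import torch
from mmdet.core.bbox.coder.legacy_delta_xywh_bbox_coder import (
    legacy_bbox2delta, legacy_delta2bbox)

proposals = torch.tensor([[0., 0., 9., 9.]])   # legacy width/height: 9 - 0 + 1 = 10
gts = torch.tensor([[0., 0., 19., 19.]])       # legacy width/height: 20

deltas = legacy_bbox2delta(proposals, gts)
# dw = dh = log(20 / 10) ~ 0.693; the non-legacy coder would give log(19 / 9) ~ 0.747.
decoded = legacy_delta2bbox(proposals, deltas)
# decoded ~ [[-0.5, -0.5, 19.5, 19.5]]: shifted by 0.5 since the +-0.5 step is omitted.
print(deltas, decoded)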
diff --git a/mmdet/core/bbox/coder/pseudo_bbox_coder.py b/mmdet/core/bbox/coder/pseudo_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c8346f4ae2c7db9719a70c7dc0244e088a9965b
--- /dev/null
+++ b/mmdet/core/bbox/coder/pseudo_bbox_coder.py
@@ -0,0 +1,18 @@
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class PseudoBBoxCoder(BaseBBoxCoder):
+ """Pseudo bounding box coder."""
+
+ def __init__(self, **kwargs):
+ super(BaseBBoxCoder, self).__init__(**kwargs)
+
+    def encode(self, bboxes, gt_bboxes):
+        """torch.Tensor: return the given ``gt_bboxes``"""
+ return gt_bboxes
+
+ def decode(self, bboxes, pred_bboxes):
+ """torch.Tensor: return the given ``pred_bboxes``"""
+ return pred_bboxes
diff --git a/mmdet/core/bbox/coder/tblr_bbox_coder.py b/mmdet/core/bbox/coder/tblr_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..edaffaf1fa252857e1a660ea14a613e2466fb52c
--- /dev/null
+++ b/mmdet/core/bbox/coder/tblr_bbox_coder.py
@@ -0,0 +1,198 @@
+import mmcv
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class TBLRBBoxCoder(BaseBBoxCoder):
+ """TBLR BBox coder.
+
+    Following the practice in `FSAF <https://arxiv.org/abs/1903.00621>`_,
+    this coder encodes gt bboxes (x1, y1, x2, y2) into (top, bottom, left,
+    right) and decodes it back to the original.
+
+ Args:
+        normalizer (list | float): Normalization factor that the coordinates
+          are divided by when encoding. If it is a list, it must have a length
+          of 4, giving the normalization factor for each of the tblr dims.
+          Otherwise it is a single float factor for all dims. Default: 4.0
+        clip_border (bool, optional): Whether to clip the objects outside the
+            border of the image. Defaults to True.
+ """
+
+ def __init__(self, normalizer=4.0, clip_border=True):
+ super(BaseBBoxCoder, self).__init__()
+ self.normalizer = normalizer
+ self.clip_border = clip_border
+
+ def encode(self, bboxes, gt_bboxes):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes`` in the (top, left,
+ bottom, right) order.
+
+ Args:
+ bboxes (torch.Tensor): source boxes, e.g., object proposals.
+ gt_bboxes (torch.Tensor): target of the transformation, e.g.,
+ ground truth boxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ encoded_bboxes = bboxes2tblr(
+ bboxes, gt_bboxes, normalizer=self.normalizer)
+ return encoded_bboxes
+
+ def decode(self, bboxes, pred_bboxes, max_shape=None):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+            bboxes (torch.Tensor): Basic boxes. Shape (B, N, 4) or (N, 4)
+ pred_bboxes (torch.Tensor): Encoded boxes with shape
+ (B, N, 4) or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If bboxes shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ decoded_bboxes = tblr2bboxes(
+ bboxes,
+ pred_bboxes,
+ normalizer=self.normalizer,
+ max_shape=max_shape,
+ clip_border=self.clip_border)
+
+ return decoded_bboxes
+
+
+@mmcv.jit(coderize=True)
+def bboxes2tblr(priors, gts, normalizer=4.0, normalize_by_wh=True):
+ """Encode ground truth boxes to tblr coordinate.
+
+    It first converts the gt coordinates into the tblr format,
+    (top, bottom, left, right), relative to the prior box centers.
+    The tblr coordinates may be normalized by the side lengths (w, h) of the
+    prior bboxes if `normalize_by_wh` is True, and they are then divided by
+    the `normalizer` factor.
+
+ Args:
+ priors (Tensor): Prior boxes in point form
+ Shape: (num_proposals,4).
+ gts (Tensor): Coords of ground truth for each prior in point-form
+ Shape: (num_proposals, 4).
+ normalizer (Sequence[float] | float): normalization parameter of
+ encoded boxes. If it is a list, it has to have length = 4.
+ Default: 4.0
+ normalize_by_wh (bool): Whether to normalize tblr coordinate by the
+ side length (wh) of prior bboxes.
+
+ Return:
+ encoded boxes (Tensor), Shape: (num_proposals, 4)
+ """
+
+ # dist b/t match center and prior's center
+ if not isinstance(normalizer, float):
+ normalizer = torch.tensor(normalizer, device=priors.device)
+ assert len(normalizer) == 4, 'Normalizer must have length = 4'
+ assert priors.size(0) == gts.size(0)
+ prior_centers = (priors[:, 0:2] + priors[:, 2:4]) / 2
+ xmin, ymin, xmax, ymax = gts.split(1, dim=1)
+ top = prior_centers[:, 1].unsqueeze(1) - ymin
+ bottom = ymax - prior_centers[:, 1].unsqueeze(1)
+ left = prior_centers[:, 0].unsqueeze(1) - xmin
+ right = xmax - prior_centers[:, 0].unsqueeze(1)
+ loc = torch.cat((top, bottom, left, right), dim=1)
+ if normalize_by_wh:
+ # Normalize tblr by anchor width and height
+ wh = priors[:, 2:4] - priors[:, 0:2]
+ w, h = torch.split(wh, 1, dim=1)
+ loc[:, :2] /= h # tb is normalized by h
+ loc[:, 2:] /= w # lr is normalized by w
+ # Normalize tblr by the given normalization factor
+ return loc / normalizer
+
+
+@mmcv.jit(coderize=True)
+def tblr2bboxes(priors,
+ tblr,
+ normalizer=4.0,
+ normalize_by_wh=True,
+ max_shape=None,
+ clip_border=True):
+ """Decode tblr outputs to prediction boxes.
+
+ The process includes 3 steps: 1) De-normalize tblr coordinates by
+ multiplying it with `normalizer`; 2) De-normalize tblr coordinates by the
+ prior bbox width and height if `normalize_by_wh` is `True`; 3) Convert
+ tblr (top, bottom, left, right) pair relative to the center of priors back
+ to (xmin, ymin, xmax, ymax) coordinate.
+
+ Args:
+ priors (Tensor): Prior boxes in point form (x0, y0, x1, y1)
+ Shape: (N,4) or (B, N, 4).
+ tblr (Tensor): Coords of network output in tblr form
+ Shape: (N, 4) or (B, N, 4).
+        normalizer (Sequence[float] | float): Normalization parameter of
+          encoded boxes. If a list, it gives the normalization factor for each
+          of the tblr dims. If a float, it is a single normalization factor
+          for all dims. Default: 4.0
+ normalize_by_wh (bool): Whether the tblr coordinates have been
+ normalized by the side length (wh) of prior bboxes.
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+        clip_border (bool, optional): Whether to clip the objects outside the
+ border of the image. Defaults to True.
+
+ Return:
+        decoded boxes (Tensor): Boxes with shape (N, 4) or (B, N, 4)
+ """
+ if not isinstance(normalizer, float):
+ normalizer = torch.tensor(normalizer, device=priors.device)
+ assert len(normalizer) == 4, 'Normalizer must have length = 4'
+ assert priors.size(0) == tblr.size(0)
+ if priors.ndim == 3:
+ assert priors.size(1) == tblr.size(1)
+
+ loc_decode = tblr * normalizer
+ prior_centers = (priors[..., 0:2] + priors[..., 2:4]) / 2
+ if normalize_by_wh:
+ wh = priors[..., 2:4] - priors[..., 0:2]
+ w, h = torch.split(wh, 1, dim=-1)
+        # Inplace operation with slicing would fail when exporting to ONNX
+ th = h * loc_decode[..., :2] # tb
+ tw = w * loc_decode[..., 2:] # lr
+ loc_decode = torch.cat([th, tw], dim=-1)
+ # Cannot be exported using onnx when loc_decode.split(1, dim=-1)
+ top, bottom, left, right = loc_decode.split((1, 1, 1, 1), dim=-1)
+ xmin = prior_centers[..., 0].unsqueeze(-1) - left
+ xmax = prior_centers[..., 0].unsqueeze(-1) + right
+ ymin = prior_centers[..., 1].unsqueeze(-1) - top
+ ymax = prior_centers[..., 1].unsqueeze(-1) + bottom
+
+ bboxes = torch.cat((xmin, ymin, xmax, ymax), dim=-1)
+
+ if clip_border and max_shape is not None:
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = priors.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(priors)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = priors.new_tensor(0)
+ max_xy = torch.cat([max_shape, max_shape],
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
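
A minimal round-trip sketch for the two helpers above (illustrative boxes, default `normalizer=4.0`): encoding ground truth against priors and decoding again reproduces the ground truth, which is a handy check of the tblr ordering conventions:

import torch
from mmdet.core.bbox.coder.tblr_bbox_coder import bboxes2tblr, tblr2bboxes

priors = torch.tensor([[0., 0., 10., 10.], [10., 10., 30., 30.]])
gts = torch.tensor([[1., 2., 9., 8.], [12., 11., 28., 31.]])

tblr = bboxes2tblr(priors, gts)      # (top, bottom, left, right), wh- and normalizer-scaled
decoded = tblr2bboxes(priors, tblr)  # back to (x1, y1, x2, y2)
assert torch.allclose(decoded, gts, atol=1e-5)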
diff --git a/mmdet/core/bbox/coder/yolo_bbox_coder.py b/mmdet/core/bbox/coder/yolo_bbox_coder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6d0e82ac780820952938d8751ac9776ea31588a
--- /dev/null
+++ b/mmdet/core/bbox/coder/yolo_bbox_coder.py
@@ -0,0 +1,89 @@
+import mmcv
+import torch
+
+from ..builder import BBOX_CODERS
+from .base_bbox_coder import BaseBBoxCoder
+
+
+@BBOX_CODERS.register_module()
+class YOLOBBoxCoder(BaseBBoxCoder):
+ """YOLO BBox coder.
+
+    Following YOLO, this coder divides the image into grids, and encodes
+    bbox (x1, y1, x2, y2) into (cx, cy, dw, dh). cx, cy in [0., 1.] denote the
+    relative center position w.r.t. the center of bboxes. dw, dh are the same
+    as :obj:`DeltaXYWHBBoxCoder`.
+
+ Args:
+ eps (float): Min value of cx, cy when encoding.
+ """
+
+ def __init__(self, eps=1e-6):
+ super(BaseBBoxCoder, self).__init__()
+ self.eps = eps
+
+ @mmcv.jit(coderize=True)
+ def encode(self, bboxes, gt_bboxes, stride):
+ """Get box regression transformation deltas that can be used to
+ transform the ``bboxes`` into the ``gt_bboxes``.
+
+ Args:
+ bboxes (torch.Tensor): Source boxes, e.g., anchors.
+ gt_bboxes (torch.Tensor): Target of the transformation, e.g.,
+ ground-truth boxes.
+ stride (torch.Tensor | int): Stride of bboxes.
+
+ Returns:
+ torch.Tensor: Box transformation deltas
+ """
+
+ assert bboxes.size(0) == gt_bboxes.size(0)
+ assert bboxes.size(-1) == gt_bboxes.size(-1) == 4
+ x_center_gt = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) * 0.5
+ y_center_gt = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) * 0.5
+ w_gt = gt_bboxes[..., 2] - gt_bboxes[..., 0]
+ h_gt = gt_bboxes[..., 3] - gt_bboxes[..., 1]
+ x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
+ y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
+ w = bboxes[..., 2] - bboxes[..., 0]
+ h = bboxes[..., 3] - bboxes[..., 1]
+ w_target = torch.log((w_gt / w).clamp(min=self.eps))
+ h_target = torch.log((h_gt / h).clamp(min=self.eps))
+ x_center_target = ((x_center_gt - x_center) / stride + 0.5).clamp(
+ self.eps, 1 - self.eps)
+ y_center_target = ((y_center_gt - y_center) / stride + 0.5).clamp(
+ self.eps, 1 - self.eps)
+ encoded_bboxes = torch.stack(
+ [x_center_target, y_center_target, w_target, h_target], dim=-1)
+ return encoded_bboxes
+
+ @mmcv.jit(coderize=True)
+ def decode(self, bboxes, pred_bboxes, stride):
+ """Apply transformation `pred_bboxes` to `boxes`.
+
+ Args:
+            bboxes (torch.Tensor): Basic boxes, e.g. anchors.
+            pred_bboxes (torch.Tensor): Encoded boxes with shape (..., 4).
+            stride (torch.Tensor | int): Strides of bboxes.
+
+ Returns:
+ torch.Tensor: Decoded boxes.
+ """
+ assert pred_bboxes.size(0) == bboxes.size(0)
+ assert pred_bboxes.size(-1) == bboxes.size(-1) == 4
+ x_center = (bboxes[..., 0] + bboxes[..., 2]) * 0.5
+ y_center = (bboxes[..., 1] + bboxes[..., 3]) * 0.5
+ w = bboxes[..., 2] - bboxes[..., 0]
+ h = bboxes[..., 3] - bboxes[..., 1]
+ # Get outputs x, y
+ x_center_pred = (pred_bboxes[..., 0] - 0.5) * stride + x_center
+ y_center_pred = (pred_bboxes[..., 1] - 0.5) * stride + y_center
+ w_pred = torch.exp(pred_bboxes[..., 2]) * w
+ h_pred = torch.exp(pred_bboxes[..., 3]) * h
+
+ decoded_bboxes = torch.stack(
+ (x_center_pred - w_pred / 2, y_center_pred - h_pred / 2,
+ x_center_pred + w_pred / 2, y_center_pred + h_pred / 2),
+ dim=-1)
+
+ return decoded_bboxes
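
A small hedged round trip for the YOLO coder above. Because the encoded center offsets are clamped to [eps, 1 - eps], the round trip is only exact when the gt center lies within half a stride of the anchor center, as in this example:

import torch
from mmdet.core.bbox.coder.yolo_bbox_coder import YOLOBBoxCoder

coder = YOLOBBoxCoder()
stride = 32
anchors = torch.tensor([[0., 0., 32., 32.]])  # center (16, 16)
gts = torch.tensor([[4., 6., 36., 30.]])      # center (20, 18), within stride / 2

encoded = coder.encode(anchors, gts, stride)      # (cx, cy, dw, dh) targets
decoded = coder.decode(anchors, encoded, stride)  # back to (x1, y1, x2, y2)
assert torch.allclose(decoded, gts, atol=1e-4)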
diff --git a/mmdet/core/bbox/demodata.py b/mmdet/core/bbox/demodata.py
new file mode 100644
index 0000000000000000000000000000000000000000..feecb693745a47d9f2bebd8af9a217ff4f5cc92b
--- /dev/null
+++ b/mmdet/core/bbox/demodata.py
@@ -0,0 +1,41 @@
+import numpy as np
+import torch
+
+from mmdet.utils.util_random import ensure_rng
+
+
+def random_boxes(num=1, scale=1, rng=None):
+ """Simple version of ``kwimage.Boxes.random``
+
+ Returns:
+ Tensor: shape (n, 4) in x1, y1, x2, y2 format.
+
+ References:
+ https://gitlab.kitware.com/computer-vision/kwimage/blob/master/kwimage/structs/boxes.py#L1390
+
+ Example:
+ >>> num = 3
+ >>> scale = 512
+ >>> rng = 0
+ >>> boxes = random_boxes(num, scale, rng)
+ >>> print(boxes)
+ tensor([[280.9925, 278.9802, 308.6148, 366.1769],
+ [216.9113, 330.6978, 224.0446, 456.5878],
+ [405.3632, 196.3221, 493.3953, 270.7942]])
+ """
+ rng = ensure_rng(rng)
+
+ tlbr = rng.rand(num, 4).astype(np.float32)
+
+ tl_x = np.minimum(tlbr[:, 0], tlbr[:, 2])
+ tl_y = np.minimum(tlbr[:, 1], tlbr[:, 3])
+ br_x = np.maximum(tlbr[:, 0], tlbr[:, 2])
+ br_y = np.maximum(tlbr[:, 1], tlbr[:, 3])
+
+ tlbr[:, 0] = tl_x * scale
+ tlbr[:, 1] = tl_y * scale
+ tlbr[:, 2] = br_x * scale
+ tlbr[:, 3] = br_y * scale
+
+ boxes = torch.from_numpy(tlbr)
+ return boxes
diff --git a/mmdet/core/bbox/iou_calculators/__init__.py b/mmdet/core/bbox/iou_calculators/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e71369a58a05fa25e6a754300875fdbb87cb26a5
--- /dev/null
+++ b/mmdet/core/bbox/iou_calculators/__init__.py
@@ -0,0 +1,4 @@
+from .builder import build_iou_calculator
+from .iou2d_calculator import BboxOverlaps2D, bbox_overlaps
+
+__all__ = ['build_iou_calculator', 'BboxOverlaps2D', 'bbox_overlaps']
diff --git a/mmdet/core/bbox/iou_calculators/builder.py b/mmdet/core/bbox/iou_calculators/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..09094d7ece46a9f18a28ed0960feac2afa9331bb
--- /dev/null
+++ b/mmdet/core/bbox/iou_calculators/builder.py
@@ -0,0 +1,8 @@
+from mmcv.utils import Registry, build_from_cfg
+
+IOU_CALCULATORS = Registry('IoU calculator')
+
+
+def build_iou_calculator(cfg, default_args=None):
+ """Builder of IoU calculator."""
+ return build_from_cfg(cfg, IOU_CALCULATORS, default_args)
diff --git a/mmdet/core/bbox/iou_calculators/iou2d_calculator.py b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
new file mode 100644
index 0000000000000000000000000000000000000000..158b702c234f5c10c4f5f03e08e8794ac7b8dcad
--- /dev/null
+++ b/mmdet/core/bbox/iou_calculators/iou2d_calculator.py
@@ -0,0 +1,159 @@
+import torch
+
+from .builder import IOU_CALCULATORS
+
+
+@IOU_CALCULATORS.register_module()
+class BboxOverlaps2D(object):
+ """2D Overlaps (e.g. IoUs, GIoUs) Calculator."""
+
+ def __call__(self, bboxes1, bboxes2, mode='iou', is_aligned=False):
+ """Calculate IoU between 2D bboxes.
+
+ Args:
+            bboxes1 (Tensor): bboxes have shape (m, 4) in (x1, y1, x2, y2)
+                format, or shape (m, 5) in (x1, y1, x2, y2, score) format.
+            bboxes2 (Tensor): bboxes have shape (n, 4) in (x1, y1, x2, y2)
+                format, shape (n, 5) in (x1, y1, x2, y2, score) format, or be
+                empty. If ``is_aligned`` is ``True``, then m and n must be
+                equal.
+ mode (str): "iou" (intersection over union), "iof" (intersection
+ over foreground), or "giou" (generalized intersection over
+ union).
+ is_aligned (bool, optional): If True, then m and n must be equal.
+ Default False.
+
+ Returns:
+ Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
+ """
+ assert bboxes1.size(-1) in [0, 4, 5]
+ assert bboxes2.size(-1) in [0, 4, 5]
+ if bboxes2.size(-1) == 5:
+ bboxes2 = bboxes2[..., :4]
+ if bboxes1.size(-1) == 5:
+ bboxes1 = bboxes1[..., :4]
+ return bbox_overlaps(bboxes1, bboxes2, mode, is_aligned)
+
+ def __repr__(self):
+ """str: a string describing the module"""
+ repr_str = self.__class__.__name__ + '()'
+ return repr_str
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', is_aligned=False, eps=1e-6):
+ """Calculate overlap between two set of bboxes.
+
+ If ``is_aligned `` is ``False``, then calculate the overlaps between each
+ bbox of bboxes1 and bboxes2, otherwise the overlaps between each aligned
+ pair of bboxes1 and bboxes2.
+
+ Args:
+        bboxes1 (Tensor): shape (B, m, 4) in (x1, y1, x2, y2) format or empty.
+        bboxes2 (Tensor): shape (B, n, 4) in (x1, y1, x2, y2) format or empty.
+            B indicates the batch dim, in shape (B1, B2, ..., Bn).
+            If ``is_aligned`` is ``True``, then m and n must be equal.
+ mode (str): "iou" (intersection over union), "iof" (intersection over
+ foreground) or "giou" (generalized intersection over union).
+ Default "iou".
+ is_aligned (bool, optional): If True, then m and n must be equal.
+ Default False.
+ eps (float, optional): A value added to the denominator for numerical
+ stability. Default 1e-6.
+
+ Returns:
+ Tensor: shape (m, n) if ``is_aligned `` is False else shape (m,)
+
+ Example:
+ >>> bboxes1 = torch.FloatTensor([
+ >>> [0, 0, 10, 10],
+ >>> [10, 10, 20, 20],
+ >>> [32, 32, 38, 42],
+ >>> ])
+ >>> bboxes2 = torch.FloatTensor([
+ >>> [0, 0, 10, 20],
+ >>> [0, 10, 10, 19],
+ >>> [10, 10, 20, 20],
+ >>> ])
+ >>> overlaps = bbox_overlaps(bboxes1, bboxes2)
+ >>> assert overlaps.shape == (3, 3)
+ >>> overlaps = bbox_overlaps(bboxes1, bboxes2, is_aligned=True)
+ >>> assert overlaps.shape == (3, )
+
+ Example:
+ >>> empty = torch.empty(0, 4)
+ >>> nonempty = torch.FloatTensor([[0, 0, 10, 9]])
+ >>> assert tuple(bbox_overlaps(empty, nonempty).shape) == (0, 1)
+ >>> assert tuple(bbox_overlaps(nonempty, empty).shape) == (1, 0)
+ >>> assert tuple(bbox_overlaps(empty, empty).shape) == (0, 0)
+ """
+
+ assert mode in ['iou', 'iof', 'giou'], f'Unsupported mode {mode}'
+ # Either the boxes are empty or the length of boxes' last dimension is 4
+ assert (bboxes1.size(-1) == 4 or bboxes1.size(0) == 0)
+ assert (bboxes2.size(-1) == 4 or bboxes2.size(0) == 0)
+
+ # Batch dim must be the same
+ # Batch dim: (B1, B2, ... Bn)
+ assert bboxes1.shape[:-2] == bboxes2.shape[:-2]
+ batch_shape = bboxes1.shape[:-2]
+
+ rows = bboxes1.size(-2)
+ cols = bboxes2.size(-2)
+ if is_aligned:
+ assert rows == cols
+
+ if rows * cols == 0:
+ if is_aligned:
+ return bboxes1.new(batch_shape + (rows, ))
+ else:
+ return bboxes1.new(batch_shape + (rows, cols))
+
+ area1 = (bboxes1[..., 2] - bboxes1[..., 0]) * (
+ bboxes1[..., 3] - bboxes1[..., 1])
+ area2 = (bboxes2[..., 2] - bboxes2[..., 0]) * (
+ bboxes2[..., 3] - bboxes2[..., 1])
+
+ if is_aligned:
+ lt = torch.max(bboxes1[..., :2], bboxes2[..., :2]) # [B, rows, 2]
+ rb = torch.min(bboxes1[..., 2:], bboxes2[..., 2:]) # [B, rows, 2]
+
+ wh = (rb - lt).clamp(min=0) # [B, rows, 2]
+ overlap = wh[..., 0] * wh[..., 1]
+
+ if mode in ['iou', 'giou']:
+ union = area1 + area2 - overlap
+ else:
+ union = area1
+ if mode == 'giou':
+ enclosed_lt = torch.min(bboxes1[..., :2], bboxes2[..., :2])
+ enclosed_rb = torch.max(bboxes1[..., 2:], bboxes2[..., 2:])
+ else:
+ lt = torch.max(bboxes1[..., :, None, :2],
+ bboxes2[..., None, :, :2]) # [B, rows, cols, 2]
+ rb = torch.min(bboxes1[..., :, None, 2:],
+ bboxes2[..., None, :, 2:]) # [B, rows, cols, 2]
+
+ wh = (rb - lt).clamp(min=0) # [B, rows, cols, 2]
+ overlap = wh[..., 0] * wh[..., 1]
+
+ if mode in ['iou', 'giou']:
+ union = area1[..., None] + area2[..., None, :] - overlap
+ else:
+ union = area1[..., None]
+ if mode == 'giou':
+ enclosed_lt = torch.min(bboxes1[..., :, None, :2],
+ bboxes2[..., None, :, :2])
+ enclosed_rb = torch.max(bboxes1[..., :, None, 2:],
+ bboxes2[..., None, :, 2:])
+
+ eps = union.new_tensor([eps])
+ union = torch.max(union, eps)
+ ious = overlap / union
+ if mode in ['iou', 'iof']:
+ return ious
+ # calculate gious
+ enclose_wh = (enclosed_rb - enclosed_lt).clamp(min=0)
+ enclose_area = enclose_wh[..., 0] * enclose_wh[..., 1]
+ enclose_area = torch.max(enclose_area, eps)
+ gious = ious - (enclose_area - union) / enclose_area
+ return gious
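
A worked numeric check for the calculator above. For the pair below the intersection is 25, the union 175, and the smallest enclosing box 225, so IoU = 25/175 (about 0.143) and GIoU = 0.143 - 50/225 (about -0.079); GIoU can therefore go negative for poorly overlapping boxes:

import torch
from mmdet.core.bbox.iou_calculators import bbox_overlaps

b1 = torch.tensor([[0., 0., 10., 10.]])
b2 = torch.tensor([[5., 5., 15., 15.]])

iou = bbox_overlaps(b1, b2, mode='iou', is_aligned=True)    # ~0.1429 = 25 / 175
giou = bbox_overlaps(b1, b2, mode='giou', is_aligned=True)  # ~-0.0794 = 0.1429 - 50 / 225
print(iou, giou)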
diff --git a/mmdet/core/bbox/match_costs/__init__.py b/mmdet/core/bbox/match_costs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..add5e0d394034d89b2d47c314ff1938294deb6ea
--- /dev/null
+++ b/mmdet/core/bbox/match_costs/__init__.py
@@ -0,0 +1,7 @@
+from .builder import build_match_cost
+from .match_cost import BBoxL1Cost, ClassificationCost, FocalLossCost, IoUCost
+
+__all__ = [
+ 'build_match_cost', 'ClassificationCost', 'BBoxL1Cost', 'IoUCost',
+ 'FocalLossCost'
+]
diff --git a/mmdet/core/bbox/match_costs/builder.py b/mmdet/core/bbox/match_costs/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..6894017d42eb16ee4a8ae3ed660a71cda3ad9940
--- /dev/null
+++ b/mmdet/core/bbox/match_costs/builder.py
@@ -0,0 +1,8 @@
+from mmcv.utils import Registry, build_from_cfg
+
+MATCH_COST = Registry('Match Cost')
+
+
+def build_match_cost(cfg, default_args=None):
+    """Builder of matching cost."""
+ return build_from_cfg(cfg, MATCH_COST, default_args)
diff --git a/mmdet/core/bbox/match_costs/match_cost.py b/mmdet/core/bbox/match_costs/match_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..38869737d66064ee5adea4b2c8ff26ae091e5f56
--- /dev/null
+++ b/mmdet/core/bbox/match_costs/match_cost.py
@@ -0,0 +1,184 @@
+import torch
+
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh
+from .builder import MATCH_COST
+
+
+@MATCH_COST.register_module()
+class BBoxL1Cost(object):
+ """BBoxL1Cost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ box_format (str, optional): 'xyxy' for DETR, 'xywh' for Sparse_RCNN
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import BBoxL1Cost
+ >>> import torch
+ >>> self = BBoxL1Cost()
+ >>> bbox_pred = torch.rand(1, 4)
+ >>> gt_bboxes= torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(bbox_pred, gt_bboxes, factor)
+ tensor([[1.6172, 1.6422]])
+ """
+
+ def __init__(self, weight=1., box_format='xyxy'):
+ self.weight = weight
+ assert box_format in ['xyxy', 'xywh']
+ self.box_format = box_format
+
+ def __call__(self, bbox_pred, gt_bboxes):
+ """
+ Args:
+ bbox_pred (Tensor): Predicted boxes with normalized coordinates
+ (cx, cy, w, h), which are all in range [0, 1]. Shape
+ [num_query, 4].
+ gt_bboxes (Tensor): Ground truth boxes with normalized
+ coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+
+ Returns:
+ torch.Tensor: bbox_cost value with weight
+ """
+ if self.box_format == 'xywh':
+ gt_bboxes = bbox_xyxy_to_cxcywh(gt_bboxes)
+ elif self.box_format == 'xyxy':
+ bbox_pred = bbox_cxcywh_to_xyxy(bbox_pred)
+ bbox_cost = torch.cdist(bbox_pred, gt_bboxes, p=1)
+ return bbox_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class FocalLossCost(object):
+ """FocalLossCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+ alpha (int | float, optional): focal_loss alpha
+ gamma (int | float, optional): focal_loss gamma
+ eps (float, optional): default 1e-12
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import FocalLossCost
+ >>> import torch
+ >>> self = FocalLossCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3236, -0.3364, -0.2699],
+ [-0.3439, -0.3209, -0.4807],
+ [-0.4099, -0.3795, -0.2929],
+ [-0.1950, -0.1207, -0.2626]])
+ """
+
+ def __init__(self, weight=1., alpha=0.25, gamma=2, eps=1e-12):
+ self.weight = weight
+ self.alpha = alpha
+ self.gamma = gamma
+ self.eps = eps
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+ cls_pred = cls_pred.sigmoid()
+ neg_cost = -(1 - cls_pred + self.eps).log() * (
+ 1 - self.alpha) * cls_pred.pow(self.gamma)
+ pos_cost = -(cls_pred + self.eps).log() * self.alpha * (
+ 1 - cls_pred).pow(self.gamma)
+ cls_cost = pos_cost[:, gt_labels] - neg_cost[:, gt_labels]
+ return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class ClassificationCost(object):
+ """ClsSoftmaxCost.
+
+ Args:
+ weight (int | float, optional): loss_weight
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import \
+ ... ClassificationCost
+ >>> import torch
+ >>> self = ClassificationCost()
+ >>> cls_pred = torch.rand(4, 3)
+ >>> gt_labels = torch.tensor([0, 1, 2])
+ >>> factor = torch.tensor([10, 8, 10, 8])
+ >>> self(cls_pred, gt_labels)
+ tensor([[-0.3430, -0.3525, -0.3045],
+ [-0.3077, -0.2931, -0.3992],
+ [-0.3664, -0.3455, -0.2881],
+ [-0.3343, -0.2701, -0.3956]])
+ """
+
+ def __init__(self, weight=1.):
+ self.weight = weight
+
+ def __call__(self, cls_pred, gt_labels):
+ """
+ Args:
+ cls_pred (Tensor): Predicted classification logits, shape
+ [num_query, num_class].
+ gt_labels (Tensor): Label of `gt_bboxes`, shape (num_gt,).
+
+ Returns:
+ torch.Tensor: cls_cost value with weight
+ """
+        # Following the official DETR repo, the matching cost uses
+        # 1 - cls_score[gt_label] instead of the NLL loss used for training.
+        # The 1 is a constant that doesn't change the matching,
+        # so it can be omitted.
+ cls_score = cls_pred.softmax(-1)
+ cls_cost = -cls_score[:, gt_labels]
+ return cls_cost * self.weight
+
+
+@MATCH_COST.register_module()
+class IoUCost(object):
+ """IoUCost.
+
+ Args:
+ iou_mode (str, optional): iou mode such as 'iou' | 'giou'
+ weight (int | float, optional): loss weight
+
+ Examples:
+ >>> from mmdet.core.bbox.match_costs.match_cost import IoUCost
+ >>> import torch
+ >>> self = IoUCost()
+ >>> bboxes = torch.FloatTensor([[1,1, 2, 2], [2, 2, 3, 4]])
+ >>> gt_bboxes = torch.FloatTensor([[0, 0, 2, 4], [1, 2, 3, 4]])
+ >>> self(bboxes, gt_bboxes)
+ tensor([[-0.1250, 0.1667],
+ [ 0.1667, -0.5000]])
+ """
+
+ def __init__(self, iou_mode='giou', weight=1.):
+ self.weight = weight
+ self.iou_mode = iou_mode
+
+ def __call__(self, bboxes, gt_bboxes):
+ """
+ Args:
+ bboxes (Tensor): Predicted boxes with unnormalized coordinates
+ (x1, y1, x2, y2). Shape [num_query, 4].
+ gt_bboxes (Tensor): Ground truth boxes with unnormalized
+ coordinates (x1, y1, x2, y2). Shape [num_gt, 4].
+
+ Returns:
+ torch.Tensor: iou_cost value with weight
+ """
+ # overlaps: [num_bboxes, num_gt]
+ overlaps = bbox_overlaps(
+ bboxes, gt_bboxes, mode=self.iou_mode, is_aligned=False)
+ # The 1 is a constant that doesn't change the matching, so omitted.
+ iou_cost = -overlaps
+ return iou_cost * self.weight
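
These cost modules are meant to be summed into a single query-by-gt cost matrix that a Hungarian assigner then solves. The sketch below is illustrative only (made-up weights, SciPy's `linear_sum_assignment` instead of mmdet's own `HungarianAssigner`), but it shows how the pieces compose:

import torch
from scipy.optimize import linear_sum_assignment
from mmdet.core.bbox.match_costs.match_cost import (BBoxL1Cost,
                                                    ClassificationCost, IoUCost)
from mmdet.core.bbox.transforms import bbox_cxcywh_to_xyxy

cls_pred = torch.rand(5, 3)                      # 5 queries, 3 classes (mock logits)
bbox_pred = torch.rand(5, 4)                     # normalized (cx, cy, w, h) predictions
gt_labels = torch.tensor([0, 2])
gt_bboxes = torch.tensor([[0.1, 0.1, 0.4, 0.5],  # normalized (x1, y1, x2, y2)
                          [0.3, 0.2, 0.9, 0.8]])

cost = (ClassificationCost(weight=1.)(cls_pred, gt_labels)
        + BBoxL1Cost(weight=5.)(bbox_pred, gt_bboxes)
        + IoUCost(iou_mode='giou', weight=2.)(
            bbox_cxcywh_to_xyxy(bbox_pred), gt_bboxes))   # shape (num_query, num_gt)
row_ind, col_ind = linear_sum_assignment(cost.numpy())    # query row_ind[k] <-> gt col_ind[k]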
diff --git a/mmdet/core/bbox/samplers/__init__.py b/mmdet/core/bbox/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b06303fe1000e11c5486c40c70606a34a5208e3
--- /dev/null
+++ b/mmdet/core/bbox/samplers/__init__.py
@@ -0,0 +1,15 @@
+from .base_sampler import BaseSampler
+from .combined_sampler import CombinedSampler
+from .instance_balanced_pos_sampler import InstanceBalancedPosSampler
+from .iou_balanced_neg_sampler import IoUBalancedNegSampler
+from .ohem_sampler import OHEMSampler
+from .pseudo_sampler import PseudoSampler
+from .random_sampler import RandomSampler
+from .sampling_result import SamplingResult
+from .score_hlr_sampler import ScoreHLRSampler
+
+__all__ = [
+ 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+ 'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+ 'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler'
+]
diff --git a/mmdet/core/bbox/samplers/base_sampler.py b/mmdet/core/bbox/samplers/base_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ea35def115b49dfdad8a1f7c040ef3cd983b0d1
--- /dev/null
+++ b/mmdet/core/bbox/samplers/base_sampler.py
@@ -0,0 +1,101 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+
+from .sampling_result import SamplingResult
+
+
+class BaseSampler(metaclass=ABCMeta):
+ """Base class of samplers."""
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ **kwargs):
+ self.num = num
+ self.pos_fraction = pos_fraction
+ self.neg_pos_ub = neg_pos_ub
+ self.add_gt_as_proposals = add_gt_as_proposals
+ self.pos_sampler = self
+ self.neg_sampler = self
+
+ @abstractmethod
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Sample positive samples."""
+ pass
+
+ @abstractmethod
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Sample negative samples."""
+ pass
+
+ def sample(self,
+ assign_result,
+ bboxes,
+ gt_bboxes,
+ gt_labels=None,
+ **kwargs):
+ """Sample positive and negative bboxes.
+
+ This is a simple implementation of bbox sampling given candidates,
+ assigning results and ground truth bboxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Bbox assigning results.
+ bboxes (Tensor): Boxes to be sampled from.
+ gt_bboxes (Tensor): Ground truth bboxes.
+ gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+ Returns:
+ :obj:`SamplingResult`: Sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox import RandomSampler
+ >>> from mmdet.core.bbox import AssignResult
+ >>> from mmdet.core.bbox.demodata import ensure_rng, random_boxes
+ >>> rng = ensure_rng(None)
+ >>> assign_result = AssignResult.random(rng=rng)
+ >>> bboxes = random_boxes(assign_result.num_preds, rng=rng)
+ >>> gt_bboxes = random_boxes(assign_result.num_gts, rng=rng)
+ >>> gt_labels = None
+ >>> self = RandomSampler(num=32, pos_fraction=0.5, neg_pos_ub=-1,
+ >>> add_gt_as_proposals=False)
+ >>> self = self.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ """
+ if len(bboxes.shape) < 2:
+ bboxes = bboxes[None, :]
+
+ bboxes = bboxes[:, :4]
+
+ gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+ if self.add_gt_as_proposals and len(gt_bboxes) > 0:
+ if gt_labels is None:
+ raise ValueError(
+ 'gt_labels must be given when add_gt_as_proposals is True')
+ bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+ assign_result.add_gt_(gt_labels)
+ gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+ gt_flags = torch.cat([gt_ones, gt_flags])
+
+ num_expected_pos = int(self.num * self.pos_fraction)
+ pos_inds = self.pos_sampler._sample_pos(
+ assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
+ # We found that sampled indices have duplicated items occasionally.
+ # (may be a bug of PyTorch)
+ pos_inds = pos_inds.unique()
+ num_sampled_pos = pos_inds.numel()
+ num_expected_neg = self.num - num_sampled_pos
+ if self.neg_pos_ub >= 0:
+ _pos = max(1, num_sampled_pos)
+ neg_upper_bound = int(self.neg_pos_ub * _pos)
+ if num_expected_neg > neg_upper_bound:
+ num_expected_neg = neg_upper_bound
+ neg_inds = self.neg_sampler._sample_neg(
+ assign_result, num_expected_neg, bboxes=bboxes, **kwargs)
+ neg_inds = neg_inds.unique()
+
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags)
+ return sampling_result
diff --git a/mmdet/core/bbox/samplers/combined_sampler.py b/mmdet/core/bbox/samplers/combined_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..564729f0895b1863d94c479a67202438af45f996
--- /dev/null
+++ b/mmdet/core/bbox/samplers/combined_sampler.py
@@ -0,0 +1,20 @@
+from ..builder import BBOX_SAMPLERS, build_sampler
+from .base_sampler import BaseSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class CombinedSampler(BaseSampler):
+ """A sampler that combines positive sampler and negative sampler."""
+
+ def __init__(self, pos_sampler, neg_sampler, **kwargs):
+ super(CombinedSampler, self).__init__(**kwargs)
+ self.pos_sampler = build_sampler(pos_sampler, **kwargs)
+ self.neg_sampler = build_sampler(neg_sampler, **kwargs)
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
diff --git a/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..c735298487e14e4a0ec42913f25673cccb98a8a0
--- /dev/null
+++ b/mmdet/core/bbox/samplers/instance_balanced_pos_sampler.py
@@ -0,0 +1,55 @@
+import numpy as np
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .random_sampler import RandomSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class InstanceBalancedPosSampler(RandomSampler):
+ """Instance balanced sampler that samples equal number of positive samples
+ for each instance."""
+
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Sample positive boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): The assigned results of boxes.
+ num_expected (int): The number of expected positive samples
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+ if pos_inds.numel() != 0:
+ pos_inds = pos_inds.squeeze(1)
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ unique_gt_inds = assign_result.gt_inds[pos_inds].unique()
+ num_gts = len(unique_gt_inds)
+ num_per_gt = int(round(num_expected / float(num_gts)) + 1)
+ sampled_inds = []
+ for i in unique_gt_inds:
+ inds = torch.nonzero(
+ assign_result.gt_inds == i.item(), as_tuple=False)
+ if inds.numel() != 0:
+ inds = inds.squeeze(1)
+ else:
+ continue
+ if len(inds) > num_per_gt:
+ inds = self.random_choice(inds, num_per_gt)
+ sampled_inds.append(inds)
+ sampled_inds = torch.cat(sampled_inds)
+ if len(sampled_inds) < num_expected:
+ num_extra = num_expected - len(sampled_inds)
+ extra_inds = np.array(
+ list(set(pos_inds.cpu()) - set(sampled_inds.cpu())))
+ if len(extra_inds) > num_extra:
+ extra_inds = self.random_choice(extra_inds, num_extra)
+ extra_inds = torch.from_numpy(extra_inds).to(
+ assign_result.gt_inds.device).long()
+ sampled_inds = torch.cat([sampled_inds, extra_inds])
+ elif len(sampled_inds) > num_expected:
+ sampled_inds = self.random_choice(sampled_inds, num_expected)
+ return sampled_inds
diff --git a/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f275e430d1b57c4d9df57387b8f3ae6f0ff68cf1
--- /dev/null
+++ b/mmdet/core/bbox/samplers/iou_balanced_neg_sampler.py
@@ -0,0 +1,157 @@
+import numpy as np
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .random_sampler import RandomSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class IoUBalancedNegSampler(RandomSampler):
+ """IoU Balanced Sampling.
+
+ arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+
+    Sample proposals according to their IoU. A `floor_fraction` of the needed
+    RoIs is randomly sampled from proposals whose IoU is lower than
+    `floor_thr`. The others are sampled from proposals whose IoU is higher
+    than `floor_thr`; the IoU range above `floor_thr` is split evenly into
+    `num_bins` bins and proposals are sampled evenly from each bin.
+
+ Args:
+ num (int): number of proposals.
+ pos_fraction (float): fraction of positive proposals.
+ floor_thr (float): threshold (minimum) IoU for IoU balanced sampling,
+            set to -1 to apply IoU balanced sampling to all negatives.
+ floor_fraction (float): sampling fraction of proposals under floor_thr.
+ num_bins (int): number of bins in IoU balanced sampling.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ floor_thr=-1,
+ floor_fraction=0,
+ num_bins=3,
+ **kwargs):
+ super(IoUBalancedNegSampler, self).__init__(num, pos_fraction,
+ **kwargs)
+ assert floor_thr >= 0 or floor_thr == -1
+ assert 0 <= floor_fraction <= 1
+ assert num_bins >= 1
+
+ self.floor_thr = floor_thr
+ self.floor_fraction = floor_fraction
+ self.num_bins = num_bins
+
+ def sample_via_interval(self, max_overlaps, full_set, num_expected):
+ """Sample according to the iou interval.
+
+ Args:
+ max_overlaps (torch.Tensor): IoU between bounding boxes and ground
+ truth boxes.
+            full_set (set(int)): A full set of indices of boxes.
+            num_expected (int): Number of expected samples.
+
+ Returns:
+ np.ndarray: Indices of samples
+ """
+ max_iou = max_overlaps.max()
+ iou_interval = (max_iou - self.floor_thr) / self.num_bins
+ per_num_expected = int(num_expected / self.num_bins)
+
+ sampled_inds = []
+ for i in range(self.num_bins):
+ start_iou = self.floor_thr + i * iou_interval
+ end_iou = self.floor_thr + (i + 1) * iou_interval
+ tmp_set = set(
+ np.where(
+ np.logical_and(max_overlaps >= start_iou,
+ max_overlaps < end_iou))[0])
+ tmp_inds = list(tmp_set & full_set)
+ if len(tmp_inds) > per_num_expected:
+ tmp_sampled_set = self.random_choice(tmp_inds,
+ per_num_expected)
+ else:
+ tmp_sampled_set = np.array(tmp_inds, dtype=np.int)
+ sampled_inds.append(tmp_sampled_set)
+
+ sampled_inds = np.concatenate(sampled_inds)
+ if len(sampled_inds) < num_expected:
+ num_extra = num_expected - len(sampled_inds)
+ extra_inds = np.array(list(full_set - set(sampled_inds)))
+ if len(extra_inds) > num_extra:
+ extra_inds = self.random_choice(extra_inds, num_extra)
+ sampled_inds = np.concatenate([sampled_inds, extra_inds])
+
+ return sampled_inds
+
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Sample negative boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): The assigned results of boxes.
+ num_expected (int): The number of expected negative samples
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+ if neg_inds.numel() != 0:
+ neg_inds = neg_inds.squeeze(1)
+ if len(neg_inds) <= num_expected:
+ return neg_inds
+ else:
+ max_overlaps = assign_result.max_overlaps.cpu().numpy()
+ # balance sampling for negative samples
+ neg_set = set(neg_inds.cpu().numpy())
+
+ if self.floor_thr > 0:
+ floor_set = set(
+ np.where(
+ np.logical_and(max_overlaps >= 0,
+ max_overlaps < self.floor_thr))[0])
+ iou_sampling_set = set(
+ np.where(max_overlaps >= self.floor_thr)[0])
+ elif self.floor_thr == 0:
+ floor_set = set(np.where(max_overlaps == 0)[0])
+ iou_sampling_set = set(
+ np.where(max_overlaps > self.floor_thr)[0])
+ else:
+ floor_set = set()
+ iou_sampling_set = set(
+ np.where(max_overlaps > self.floor_thr)[0])
+ # for sampling interval calculation
+ self.floor_thr = 0
+
+ floor_neg_inds = list(floor_set & neg_set)
+ iou_sampling_neg_inds = list(iou_sampling_set & neg_set)
+ num_expected_iou_sampling = int(num_expected *
+ (1 - self.floor_fraction))
+ if len(iou_sampling_neg_inds) > num_expected_iou_sampling:
+ if self.num_bins >= 2:
+ iou_sampled_inds = self.sample_via_interval(
+ max_overlaps, set(iou_sampling_neg_inds),
+ num_expected_iou_sampling)
+ else:
+ iou_sampled_inds = self.random_choice(
+ iou_sampling_neg_inds, num_expected_iou_sampling)
+ else:
+ iou_sampled_inds = np.array(
+ iou_sampling_neg_inds, dtype=np.int)
+ num_expected_floor = num_expected - len(iou_sampled_inds)
+ if len(floor_neg_inds) > num_expected_floor:
+ sampled_floor_inds = self.random_choice(
+ floor_neg_inds, num_expected_floor)
+ else:
+ sampled_floor_inds = np.array(floor_neg_inds, dtype=np.int)
+ sampled_inds = np.concatenate(
+ (sampled_floor_inds, iou_sampled_inds))
+ if len(sampled_inds) < num_expected:
+ num_extra = num_expected - len(sampled_inds)
+ extra_inds = np.array(list(neg_set - set(sampled_inds)))
+ if len(extra_inds) > num_extra:
+ extra_inds = self.random_choice(extra_inds, num_extra)
+ sampled_inds = np.concatenate((sampled_inds, extra_inds))
+ sampled_inds = torch.from_numpy(sampled_inds).long().to(
+ assign_result.gt_inds.device)
+ return sampled_inds
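
To make the binning logic in `sample_via_interval` above easier to follow, here is a small standalone numpy sketch of the same idea (mock IoUs, illustrative numbers): negatives above `floor_thr` are split into `num_bins` equal IoU intervals and drawn roughly evenly from each, so harder (higher-IoU) negatives stay represented:

import numpy as np

rng = np.random.default_rng(0)
max_overlaps = rng.uniform(0.0, 0.5, size=1000)  # mock max IoU for each negative proposal
floor_thr, num_bins, num_expected = 0.1, 3, 30

iou_interval = (max_overlaps.max() - floor_thr) / num_bins
per_bin = num_expected // num_bins
sampled = []
for i in range(num_bins):
    lo = floor_thr + i * iou_interval
    hi = floor_thr + (i + 1) * iou_interval
    bin_inds = np.where((max_overlaps >= lo) & (max_overlaps < hi))[0]
    sampled.append(rng.choice(bin_inds, min(per_bin, len(bin_inds)), replace=False))
sampled = np.concatenate(sampled)  # ~equal counts drawn from each IoU bin above floor_thr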
diff --git a/mmdet/core/bbox/samplers/ohem_sampler.py b/mmdet/core/bbox/samplers/ohem_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b99f60ef0176f1b7a56665fb0f59272f65b84cd
--- /dev/null
+++ b/mmdet/core/bbox/samplers/ohem_sampler.py
@@ -0,0 +1,107 @@
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from ..transforms import bbox2roi
+from .base_sampler import BaseSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class OHEMSampler(BaseSampler):
+ r"""Online Hard Example Mining Sampler described in `Training Region-based
+ Object Detectors with Online Hard Example Mining
+    <https://arxiv.org/abs/1604.03540>`_.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ context,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ **kwargs):
+ super(OHEMSampler, self).__init__(num, pos_fraction, neg_pos_ub,
+ add_gt_as_proposals)
+ self.context = context
+ if not hasattr(self.context, 'num_stages'):
+ self.bbox_head = self.context.bbox_head
+ else:
+ self.bbox_head = self.context.bbox_head[self.context.current_stage]
+
+ def hard_mining(self, inds, num_expected, bboxes, labels, feats):
+ with torch.no_grad():
+ rois = bbox2roi([bboxes])
+ if not hasattr(self.context, 'num_stages'):
+ bbox_results = self.context._bbox_forward(feats, rois)
+ else:
+ bbox_results = self.context._bbox_forward(
+ self.context.current_stage, feats, rois)
+ cls_score = bbox_results['cls_score']
+ loss = self.bbox_head.loss(
+ cls_score=cls_score,
+ bbox_pred=None,
+ rois=rois,
+ labels=labels,
+ label_weights=cls_score.new_ones(cls_score.size(0)),
+ bbox_targets=None,
+ bbox_weights=None,
+ reduction_override='none')['loss_cls']
+ _, topk_loss_inds = loss.topk(num_expected)
+ return inds[topk_loss_inds]
+
+ def _sample_pos(self,
+ assign_result,
+ num_expected,
+ bboxes=None,
+ feats=None,
+ **kwargs):
+ """Sample positive boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ num_expected (int): Number of expected positive samples
+ bboxes (torch.Tensor, optional): Boxes. Defaults to None.
+ feats (list[torch.Tensor], optional): Multi-level features.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: Indices of positive samples
+ """
+ # Sample some hard positive samples
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+ if pos_inds.numel() != 0:
+ pos_inds = pos_inds.squeeze(1)
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ return self.hard_mining(pos_inds, num_expected, bboxes[pos_inds],
+ assign_result.labels[pos_inds], feats)
+
+ def _sample_neg(self,
+ assign_result,
+ num_expected,
+ bboxes=None,
+ feats=None,
+ **kwargs):
+ """Sample negative boxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ num_expected (int): Number of expected negative samples
+ bboxes (torch.Tensor, optional): Boxes. Defaults to None.
+ feats (list[torch.Tensor], optional): Multi-level features.
+ Defaults to None.
+
+ Returns:
+ torch.Tensor: Indices of negative samples
+ """
+ # Sample some hard negative samples
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+ if neg_inds.numel() != 0:
+ neg_inds = neg_inds.squeeze(1)
+ if len(neg_inds) <= num_expected:
+ return neg_inds
+ else:
+ neg_labels = assign_result.labels.new_empty(
+ neg_inds.size(0)).fill_(self.bbox_head.num_classes)
+ return self.hard_mining(neg_inds, num_expected, bboxes[neg_inds],
+ neg_labels, feats)
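
The hard mining above ultimately reduces to ranking the candidate proposals by their classification loss and keeping the top-k. A toy illustration with a mock per-candidate loss vector:

import torch

inds = torch.tensor([3, 7, 11, 20, 42])          # candidate proposal indices
loss = torch.tensor([0.2, 1.5, 0.9, 0.1, 2.3])   # mock per-candidate classification loss
_, topk_inds = loss.topk(3)
hard_inds = inds[topk_inds]                      # tensor([42, 7, 11]): the three hardest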
diff --git a/mmdet/core/bbox/samplers/pseudo_sampler.py b/mmdet/core/bbox/samplers/pseudo_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bd81abcdc62debc14772659d7a171f20bf33364
--- /dev/null
+++ b/mmdet/core/bbox/samplers/pseudo_sampler.py
@@ -0,0 +1,41 @@
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class PseudoSampler(BaseSampler):
+    """A pseudo sampler that does not perform any actual sampling."""
+
+ def __init__(self, **kwargs):
+ pass
+
+ def _sample_pos(self, **kwargs):
+ """Sample positive samples."""
+ raise NotImplementedError
+
+ def _sample_neg(self, **kwargs):
+ """Sample negative samples."""
+ raise NotImplementedError
+
+ def sample(self, assign_result, bboxes, gt_bboxes, **kwargs):
+ """Directly returns the positive and negative indices of samples.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Assigned results
+ bboxes (torch.Tensor): Bounding boxes
+ gt_bboxes (torch.Tensor): Ground truth boxes
+
+ Returns:
+ :obj:`SamplingResult`: sampler results
+ """
+ pos_inds = torch.nonzero(
+ assign_result.gt_inds > 0, as_tuple=False).squeeze(-1).unique()
+ neg_inds = torch.nonzero(
+ assign_result.gt_inds == 0, as_tuple=False).squeeze(-1).unique()
+ gt_flags = bboxes.new_zeros(bboxes.shape[0], dtype=torch.uint8)
+ sampling_result = SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags)
+ return sampling_result
diff --git a/mmdet/core/bbox/samplers/random_sampler.py b/mmdet/core/bbox/samplers/random_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f34b006e8bb0b55c74aa1c3b792f3664ada93162
--- /dev/null
+++ b/mmdet/core/bbox/samplers/random_sampler.py
@@ -0,0 +1,78 @@
+import torch
+
+from ..builder import BBOX_SAMPLERS
+from .base_sampler import BaseSampler
+
+
+@BBOX_SAMPLERS.register_module()
+class RandomSampler(BaseSampler):
+ """Random sampler.
+
+ Args:
+ num (int): Number of samples
+ pos_fraction (float): Fraction of positive samples
+        neg_pos_ub (int, optional): Upper bound of the ratio of negative to
+            positive samples. Defaults to -1.
+ add_gt_as_proposals (bool, optional): Whether to add ground truth
+ boxes as proposals. Defaults to True.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ **kwargs):
+ from mmdet.core.bbox import demodata
+ super(RandomSampler, self).__init__(num, pos_fraction, neg_pos_ub,
+ add_gt_as_proposals)
+ self.rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+    def random_choice(self, gallery, num):
+        """Randomly select some elements from the gallery.
+
+ If `gallery` is a Tensor, the returned indices will be a Tensor;
+ If `gallery` is a ndarray or list, the returned indices will be a
+ ndarray.
+
+ Args:
+ gallery (Tensor | ndarray | list): indices pool.
+ num (int): expected sample num.
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ assert len(gallery) >= num
+
+ is_tensor = isinstance(gallery, torch.Tensor)
+ if not is_tensor:
+ if torch.cuda.is_available():
+ device = torch.cuda.current_device()
+ else:
+ device = 'cpu'
+ gallery = torch.tensor(gallery, dtype=torch.long, device=device)
+ perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
+ rand_inds = gallery[perm]
+ if not is_tensor:
+ rand_inds = rand_inds.cpu().numpy()
+ return rand_inds
+
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Randomly sample some positive samples."""
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0, as_tuple=False)
+ if pos_inds.numel() != 0:
+ pos_inds = pos_inds.squeeze(1)
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ return self.random_choice(pos_inds, num_expected)
+
+ def _sample_neg(self, assign_result, num_expected, **kwargs):
+ """Randomly sample some negative samples."""
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0, as_tuple=False)
+ if neg_inds.numel() != 0:
+ neg_inds = neg_inds.squeeze(1)
+ if len(neg_inds) <= num_expected:
+ return neg_inds
+ else:
+ return self.random_choice(neg_inds, num_expected)
diff --git a/mmdet/core/bbox/samplers/sampling_result.py b/mmdet/core/bbox/samplers/sampling_result.py
new file mode 100644
index 0000000000000000000000000000000000000000..419a8e39a3c307a7cd9cfd0565a20037ded0d646
--- /dev/null
+++ b/mmdet/core/bbox/samplers/sampling_result.py
@@ -0,0 +1,152 @@
+import torch
+
+from mmdet.utils import util_mixins
+
+
+class SamplingResult(util_mixins.NiceRepr):
+ """Bbox sampling result.
+
+ Example:
+ >>> # xdoctest: +IGNORE_WANT
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random(rng=10)
+ >>> print(f'self = {self}')
+ self =
+ """
+
+ def __init__(self, pos_inds, neg_inds, bboxes, gt_bboxes, assign_result,
+ gt_flags):
+ self.pos_inds = pos_inds
+ self.neg_inds = neg_inds
+ self.pos_bboxes = bboxes[pos_inds]
+ self.neg_bboxes = bboxes[neg_inds]
+ self.pos_is_gt = gt_flags[pos_inds]
+
+ self.num_gts = gt_bboxes.shape[0]
+ self.pos_assigned_gt_inds = assign_result.gt_inds[pos_inds] - 1
+
+ if gt_bboxes.numel() == 0:
+ # hack for index error case
+ assert self.pos_assigned_gt_inds.numel() == 0
+ self.pos_gt_bboxes = torch.empty_like(gt_bboxes).view(-1, 4)
+ else:
+ if len(gt_bboxes.shape) < 2:
+ gt_bboxes = gt_bboxes.view(-1, 4)
+
+ self.pos_gt_bboxes = gt_bboxes[self.pos_assigned_gt_inds, :]
+
+ if assign_result.labels is not None:
+ self.pos_gt_labels = assign_result.labels[pos_inds]
+ else:
+ self.pos_gt_labels = None
+
+ @property
+ def bboxes(self):
+ """torch.Tensor: concatenated positive and negative boxes"""
+ return torch.cat([self.pos_bboxes, self.neg_bboxes])
+
+ def to(self, device):
+ """Change the device of the data inplace.
+
+ Example:
+ >>> self = SamplingResult.random()
+ >>> print(f'self = {self.to(None)}')
+ >>> # xdoctest: +REQUIRES(--gpu)
+ >>> print(f'self = {self.to(0)}')
+ """
+ _dict = self.__dict__
+ for key, value in _dict.items():
+ if isinstance(value, torch.Tensor):
+ _dict[key] = value.to(device)
+ return self
+
+ def __nice__(self):
+ data = self.info.copy()
+ data['pos_bboxes'] = data.pop('pos_bboxes').shape
+ data['neg_bboxes'] = data.pop('neg_bboxes').shape
+ parts = [f"'{k}': {v!r}" for k, v in sorted(data.items())]
+ body = ' ' + ',\n '.join(parts)
+ return '{\n' + body + '\n}'
+
+ @property
+ def info(self):
+ """Returns a dictionary of info about the object."""
+ return {
+ 'pos_inds': self.pos_inds,
+ 'neg_inds': self.neg_inds,
+ 'pos_bboxes': self.pos_bboxes,
+ 'neg_bboxes': self.neg_bboxes,
+ 'pos_is_gt': self.pos_is_gt,
+ 'num_gts': self.num_gts,
+ 'pos_assigned_gt_inds': self.pos_assigned_gt_inds,
+ }
+
+ @classmethod
+ def random(cls, rng=None, **kwargs):
+ """
+ Args:
+ rng (None | int | numpy.random.RandomState): seed or state.
+ kwargs (keyword arguments):
+ - num_preds: number of predicted boxes
+ - num_gts: number of true boxes
+            - p_ignore (float): probability of a predicted box assigned to \
+ an ignored truth.
+ - p_assigned (float): probability of a predicted box not being \
+ assigned.
+ - p_use_label (float | bool): with labels or not.
+
+ Returns:
+ :obj:`SamplingResult`: Randomly generated sampling result.
+
+ Example:
+ >>> from mmdet.core.bbox.samplers.sampling_result import * # NOQA
+ >>> self = SamplingResult.random()
+ >>> print(self.__dict__)
+ """
+ from mmdet.core.bbox.samplers.random_sampler import RandomSampler
+ from mmdet.core.bbox.assigners.assign_result import AssignResult
+ from mmdet.core.bbox import demodata
+ rng = demodata.ensure_rng(rng)
+
+        # make probabilistic?
+ num = 32
+ pos_fraction = 0.5
+ neg_pos_ub = -1
+
+ assign_result = AssignResult.random(rng=rng, **kwargs)
+
+ # Note we could just compute an assignment
+ bboxes = demodata.random_boxes(assign_result.num_preds, rng=rng)
+ gt_bboxes = demodata.random_boxes(assign_result.num_gts, rng=rng)
+
+ if rng.rand() > 0.2:
+ # sometimes algorithms squeeze their data, be robust to that
+ gt_bboxes = gt_bboxes.squeeze()
+ bboxes = bboxes.squeeze()
+
+ if assign_result.labels is None:
+ gt_labels = None
+ else:
+ gt_labels = None # todo
+
+ if gt_labels is None:
+ add_gt_as_proposals = False
+ else:
+            add_gt_as_proposals = True  # make probabilistic?
+
+ sampler = RandomSampler(
+ num,
+ pos_fraction,
+ neg_pos_ub=neg_pos_ub,
+ add_gt_as_proposals=add_gt_as_proposals,
+ rng=rng)
+ self = sampler.sample(assign_result, bboxes, gt_bboxes, gt_labels)
+ return self
diff --git a/mmdet/core/bbox/samplers/score_hlr_sampler.py b/mmdet/core/bbox/samplers/score_hlr_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..11d46b97705db60fb6a4eb5fa7da10ac78acb8bc
--- /dev/null
+++ b/mmdet/core/bbox/samplers/score_hlr_sampler.py
@@ -0,0 +1,264 @@
+import torch
+from mmcv.ops import nms_match
+
+from ..builder import BBOX_SAMPLERS
+from ..transforms import bbox2roi
+from .base_sampler import BaseSampler
+from .sampling_result import SamplingResult
+
+
+@BBOX_SAMPLERS.register_module()
+class ScoreHLRSampler(BaseSampler):
+    r"""Importance-based Sample Reweighting (ISR_N), described in `Prime Sample
+    Attention in Object Detection <https://arxiv.org/abs/1904.04821>`_.
+
+    Score hierarchical local rank (HLR) differs from RandomSampler in the
+    negative sampling part. It first computes Score-HLR in a two-step way,
+    then linearly maps the Score-HLR to the loss weights.
+
+ Args:
+ num (int): Total number of sampled RoIs.
+ pos_fraction (float): Fraction of positive samples.
+ context (:class:`BaseRoIHead`): RoI head that the sampler belongs to.
+ neg_pos_ub (int): Upper bound of the ratio of num negative to num
+ positive, -1 means no upper bound.
+ add_gt_as_proposals (bool): Whether to add ground truth as proposals.
+ k (float): Power of the non-linear mapping.
+ bias (float): Shift of the non-linear mapping.
+        score_thr (float): Minimum score for a negative sample to be
+            considered as a valid bbox.
+        iou_thr (float): IoU threshold used by NMS-Match when grouping valid
+            negative samples. Default: 0.5.
+ """
+
+ def __init__(self,
+ num,
+ pos_fraction,
+ context,
+ neg_pos_ub=-1,
+ add_gt_as_proposals=True,
+ k=0.5,
+ bias=0,
+ score_thr=0.05,
+ iou_thr=0.5,
+ **kwargs):
+ super().__init__(num, pos_fraction, neg_pos_ub, add_gt_as_proposals)
+ self.k = k
+ self.bias = bias
+ self.score_thr = score_thr
+ self.iou_thr = iou_thr
+ self.context = context
+ # context of cascade detectors is a list, so distinguish them here.
+ if not hasattr(context, 'num_stages'):
+ self.bbox_roi_extractor = context.bbox_roi_extractor
+ self.bbox_head = context.bbox_head
+ self.with_shared_head = context.with_shared_head
+ if self.with_shared_head:
+ self.shared_head = context.shared_head
+ else:
+ self.bbox_roi_extractor = context.bbox_roi_extractor[
+ context.current_stage]
+ self.bbox_head = context.bbox_head[context.current_stage]
+
+ @staticmethod
+ def random_choice(gallery, num):
+ """Randomly select some elements from the gallery.
+
+ If `gallery` is a Tensor, the returned indices will be a Tensor;
+ If `gallery` is a ndarray or list, the returned indices will be a
+ ndarray.
+
+ Args:
+ gallery (Tensor | ndarray | list): indices pool.
+ num (int): expected sample num.
+
+ Returns:
+ Tensor or ndarray: sampled indices.
+ """
+ assert len(gallery) >= num
+
+ is_tensor = isinstance(gallery, torch.Tensor)
+ if not is_tensor:
+ if torch.cuda.is_available():
+ device = torch.cuda.current_device()
+ else:
+ device = 'cpu'
+ gallery = torch.tensor(gallery, dtype=torch.long, device=device)
+ perm = torch.randperm(gallery.numel(), device=gallery.device)[:num]
+ rand_inds = gallery[perm]
+ if not is_tensor:
+ rand_inds = rand_inds.cpu().numpy()
+ return rand_inds
+
+ def _sample_pos(self, assign_result, num_expected, **kwargs):
+ """Randomly sample some positive samples."""
+ pos_inds = torch.nonzero(assign_result.gt_inds > 0).flatten()
+ if pos_inds.numel() <= num_expected:
+ return pos_inds
+ else:
+ return self.random_choice(pos_inds, num_expected)
+
+ def _sample_neg(self,
+ assign_result,
+ num_expected,
+ bboxes,
+ feats=None,
+ img_meta=None,
+ **kwargs):
+ """Sample negative samples.
+
+        Sampling with Score-HLR is done in the following steps:
+        1. Take the maximum positive score prediction of each negative sample
+            as s_i.
+        2. Filter out negative samples whose s_i <= score_thr; the remaining
+            samples are called valid samples.
+        3. Use NMS-Match to divide valid samples into different groups;
+            samples in the same group will greatly overlap with each other.
+        4. Rank the matched samples in two steps to get Score-HLR.
+ (1) In the same group, rank samples with their scores.
+ (2) In the same score rank across different groups,
+ rank samples with their scores again.
+ 5. Linearly map Score-HLR to the final label weights.
+
+ Args:
+ assign_result (:obj:`AssignResult`): result of assigner.
+ num_expected (int): Expected number of samples.
+ bboxes (Tensor): bbox to be sampled.
+            feats (Tensor): Features coming from the FPN.
+ img_meta (dict): Meta information dictionary.
+ """
+ neg_inds = torch.nonzero(assign_result.gt_inds == 0).flatten()
+ num_neg = neg_inds.size(0)
+ if num_neg == 0:
+ return neg_inds, None
+ with torch.no_grad():
+ neg_bboxes = bboxes[neg_inds]
+ neg_rois = bbox2roi([neg_bboxes])
+ bbox_result = self.context._bbox_forward(feats, neg_rois)
+ cls_score, bbox_pred = bbox_result['cls_score'], bbox_result[
+ 'bbox_pred']
+
+ ori_loss = self.bbox_head.loss(
+ cls_score=cls_score,
+ bbox_pred=None,
+ rois=None,
+ labels=neg_inds.new_full((num_neg, ),
+ self.bbox_head.num_classes),
+ label_weights=cls_score.new_ones(num_neg),
+ bbox_targets=None,
+ bbox_weights=None,
+ reduction_override='none')['loss_cls']
+
+ # filter out samples with the max score lower than score_thr
+ max_score, argmax_score = cls_score.softmax(-1)[:, :-1].max(-1)
+ valid_inds = (max_score > self.score_thr).nonzero().view(-1)
+ invalid_inds = (max_score <= self.score_thr).nonzero().view(-1)
+ num_valid = valid_inds.size(0)
+ num_invalid = invalid_inds.size(0)
+
+ num_expected = min(num_neg, num_expected)
+ num_hlr = min(num_valid, num_expected)
+ num_rand = num_expected - num_hlr
+ if num_valid > 0:
+ valid_rois = neg_rois[valid_inds]
+ valid_max_score = max_score[valid_inds]
+ valid_argmax_score = argmax_score[valid_inds]
+ valid_bbox_pred = bbox_pred[valid_inds]
+
+ # valid_bbox_pred shape: [num_valid, #num_classes, 4]
+ valid_bbox_pred = valid_bbox_pred.view(
+ valid_bbox_pred.size(0), -1, 4)
+ selected_bbox_pred = valid_bbox_pred[range(num_valid),
+ valid_argmax_score]
+ pred_bboxes = self.bbox_head.bbox_coder.decode(
+ valid_rois[:, 1:], selected_bbox_pred)
+ pred_bboxes_with_score = torch.cat(
+ [pred_bboxes, valid_max_score[:, None]], -1)
+ group = nms_match(pred_bboxes_with_score, self.iou_thr)
+
+ # imp: importance
+ imp = cls_score.new_zeros(num_valid)
+ for g in group:
+ g_score = valid_max_score[g]
+ # g_score has already sorted
+ rank = g_score.new_tensor(range(g_score.size(0)))
+ imp[g] = num_valid - rank + g_score
+ _, imp_rank_inds = imp.sort(descending=True)
+ _, imp_rank = imp_rank_inds.sort()
+ hlr_inds = imp_rank_inds[:num_expected]
+
+ if num_rand > 0:
+ rand_inds = torch.randperm(num_invalid)[:num_rand]
+ select_inds = torch.cat(
+ [valid_inds[hlr_inds], invalid_inds[rand_inds]])
+ else:
+ select_inds = valid_inds[hlr_inds]
+
+ neg_label_weights = cls_score.new_ones(num_expected)
+
+ up_bound = max(num_expected, num_valid)
+ imp_weights = (up_bound -
+ imp_rank[hlr_inds].float()) / up_bound
+ neg_label_weights[:num_hlr] = imp_weights
+ neg_label_weights[num_hlr:] = imp_weights.min()
+ neg_label_weights = (self.bias +
+ (1 - self.bias) * neg_label_weights).pow(
+ self.k)
+ ori_selected_loss = ori_loss[select_inds]
+ new_loss = ori_selected_loss * neg_label_weights
+ norm_ratio = ori_selected_loss.sum() / new_loss.sum()
+ neg_label_weights *= norm_ratio
+ else:
+ neg_label_weights = cls_score.new_ones(num_expected)
+ select_inds = torch.randperm(num_neg)[:num_expected]
+
+ return neg_inds[select_inds], neg_label_weights
+
+ def sample(self,
+ assign_result,
+ bboxes,
+ gt_bboxes,
+ gt_labels=None,
+ img_meta=None,
+ **kwargs):
+ """Sample positive and negative bboxes.
+
+ This is a simple implementation of bbox sampling given candidates,
+ assigning results and ground truth bboxes.
+
+ Args:
+ assign_result (:obj:`AssignResult`): Bbox assigning results.
+ bboxes (Tensor): Boxes to be sampled from.
+ gt_bboxes (Tensor): Ground truth bboxes.
+ gt_labels (Tensor, optional): Class labels of ground truth bboxes.
+
+ Returns:
+            tuple[:obj:`SamplingResult`, Tensor]: Sampling result and negative
+                label weights.
+ """
+ bboxes = bboxes[:, :4]
+
+ gt_flags = bboxes.new_zeros((bboxes.shape[0], ), dtype=torch.uint8)
+ if self.add_gt_as_proposals:
+ bboxes = torch.cat([gt_bboxes, bboxes], dim=0)
+ assign_result.add_gt_(gt_labels)
+ gt_ones = bboxes.new_ones(gt_bboxes.shape[0], dtype=torch.uint8)
+ gt_flags = torch.cat([gt_ones, gt_flags])
+
+ num_expected_pos = int(self.num * self.pos_fraction)
+ pos_inds = self.pos_sampler._sample_pos(
+ assign_result, num_expected_pos, bboxes=bboxes, **kwargs)
+ num_sampled_pos = pos_inds.numel()
+ num_expected_neg = self.num - num_sampled_pos
+ if self.neg_pos_ub >= 0:
+ _pos = max(1, num_sampled_pos)
+ neg_upper_bound = int(self.neg_pos_ub * _pos)
+ if num_expected_neg > neg_upper_bound:
+ num_expected_neg = neg_upper_bound
+ neg_inds, neg_label_weights = self.neg_sampler._sample_neg(
+ assign_result,
+ num_expected_neg,
+ bboxes,
+ img_meta=img_meta,
+ **kwargs)
+
+ return SamplingResult(pos_inds, neg_inds, bboxes, gt_bboxes,
+ assign_result, gt_flags), neg_label_weights
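+
+
+# Illustrative note (not part of the sampler API): the rank-to-weight mapping
+# used in ``_sample_neg`` above can be sanity-checked by hand. With the default
+# bias=0 and k=0.5, a valid negative with Score-HLR rank r out of
+# up_bound = max(num_expected, num_valid) first gets w = (up_bound - r) / up_bound
+# and then the label weight (bias + (1 - bias) * w) ** k = sqrt(w), so the
+# top-ranked negative keeps weight 1.0; randomly filled invalid negatives get
+# the minimum of these weights, and all weights are rescaled by ``norm_ratio``
+# so the summed negative loss stays unchanged.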
diff --git a/mmdet/core/bbox/transforms.py b/mmdet/core/bbox/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..df55b0a496516bf7373fe96cf746c561dd713c3b
--- /dev/null
+++ b/mmdet/core/bbox/transforms.py
@@ -0,0 +1,240 @@
+import numpy as np
+import torch
+
+
+def bbox_flip(bboxes, img_shape, direction='horizontal'):
+    """Flip bboxes horizontally, vertically or diagonally.
+
+ Args:
+ bboxes (Tensor): Shape (..., 4*k)
+ img_shape (tuple): Image shape.
+ direction (str): Flip direction, options are "horizontal", "vertical",
+ "diagonal". Default: "horizontal"
+
+ Returns:
+ Tensor: Flipped bboxes.
+ """
+ assert bboxes.shape[-1] % 4 == 0
+ assert direction in ['horizontal', 'vertical', 'diagonal']
+ flipped = bboxes.clone()
+ if direction == 'horizontal':
+ flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+ flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+ flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+ else:
+ flipped[..., 0::4] = img_shape[1] - bboxes[..., 2::4]
+ flipped[..., 1::4] = img_shape[0] - bboxes[..., 3::4]
+ flipped[..., 2::4] = img_shape[1] - bboxes[..., 0::4]
+ flipped[..., 3::4] = img_shape[0] - bboxes[..., 1::4]
+ return flipped
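+
+# Quick sanity check (hypothetical values, shown as a comment only): with
+# img_shape = (100, 200), i.e. (H, W), and a single box [10, 20, 50, 60],
+# a horizontal flip maps (x1, x2) to (W - x2, W - x1), giving
+# [150, 20, 190, 60]; the box width (40) and the y coordinates are unchanged.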
+
+
+def bbox_mapping(bboxes,
+ img_shape,
+ scale_factor,
+ flip,
+ flip_direction='horizontal'):
+ """Map bboxes from the original image scale to testing scale."""
+ new_bboxes = bboxes * bboxes.new_tensor(scale_factor)
+ if flip:
+ new_bboxes = bbox_flip(new_bboxes, img_shape, flip_direction)
+ return new_bboxes
+
+
+def bbox_mapping_back(bboxes,
+ img_shape,
+ scale_factor,
+ flip,
+ flip_direction='horizontal'):
+ """Map bboxes from testing scale to original image scale."""
+ new_bboxes = bbox_flip(bboxes, img_shape,
+ flip_direction) if flip else bboxes
+ new_bboxes = new_bboxes.view(-1, 4) / new_bboxes.new_tensor(scale_factor)
+ return new_bboxes.view(bboxes.shape)
+
+
+def bbox2roi(bbox_list):
+ """Convert a list of bboxes to roi format.
+
+ Args:
+ bbox_list (list[Tensor]): a list of bboxes corresponding to a batch
+ of images.
+
+ Returns:
+ Tensor: shape (n, 5), [batch_ind, x1, y1, x2, y2]
+ """
+ rois_list = []
+ for img_id, bboxes in enumerate(bbox_list):
+ if bboxes.size(0) > 0:
+ img_inds = bboxes.new_full((bboxes.size(0), 1), img_id)
+ rois = torch.cat([img_inds, bboxes[:, :4]], dim=-1)
+ else:
+ rois = bboxes.new_zeros((0, 5))
+ rois_list.append(rois)
+ rois = torch.cat(rois_list, 0)
+ return rois
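+
+# Example (illustrative): for a batch of two images where the first image has
+# two proposals and the second has one, bbox2roi returns a (3, 5) tensor whose
+# rows look like [0, x1, y1, x2, y2], [0, ...], [1, ...]; the first column is
+# the image index, and roi2bbox below reverses this grouping.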
+
+
+def roi2bbox(rois):
+ """Convert rois to bounding box format.
+
+ Args:
+ rois (torch.Tensor): RoIs with the shape (n, 5) where the first
+ column indicates batch id of each RoI.
+
+ Returns:
+ list[torch.Tensor]: Converted boxes of corresponding rois.
+ """
+ bbox_list = []
+ img_ids = torch.unique(rois[:, 0].cpu(), sorted=True)
+ for img_id in img_ids:
+ inds = (rois[:, 0] == img_id.item())
+ bbox = rois[inds, 1:]
+ bbox_list.append(bbox)
+ return bbox_list
+
+
+def bbox2result(bboxes, labels, num_classes):
+ """Convert detection results to a list of numpy arrays.
+
+ Args:
+ bboxes (torch.Tensor | np.ndarray): shape (n, 5)
+ labels (torch.Tensor | np.ndarray): shape (n, )
+ num_classes (int): class number, including background class
+
+ Returns:
+ list(ndarray): bbox results of each class
+ """
+ if bboxes.shape[0] == 0:
+ return [np.zeros((0, 5), dtype=np.float32) for i in range(num_classes)]
+ else:
+ if isinstance(bboxes, torch.Tensor):
+ bboxes = bboxes.detach().cpu().numpy()
+ labels = labels.detach().cpu().numpy()
+ return [bboxes[labels == i, :] for i in range(num_classes)]
+
+
+def distance2bbox(points, distance, max_shape=None):
+ """Decode distance prediction to bounding box.
+
+ Args:
+ points (Tensor): Shape (B, N, 2) or (N, 2).
+ distance (Tensor): Distance from the given point to 4
+ boundaries (left, top, right, bottom). Shape (B, N, 4) or (N, 4)
+ max_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]],optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If priors shape is (B, N, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+
+ Returns:
+ Tensor: Boxes with shape (N, 4) or (B, N, 4)
+ """
+ x1 = points[..., 0] - distance[..., 0]
+ y1 = points[..., 1] - distance[..., 1]
+ x2 = points[..., 0] + distance[..., 2]
+ y2 = points[..., 1] + distance[..., 3]
+
+ bboxes = torch.stack([x1, y1, x2, y2], -1)
+
+ if max_shape is not None:
+ if not isinstance(max_shape, torch.Tensor):
+ max_shape = x1.new_tensor(max_shape)
+ max_shape = max_shape[..., :2].type_as(x1)
+ if max_shape.ndim == 2:
+ assert bboxes.ndim == 3
+ assert max_shape.size(0) == bboxes.size(0)
+
+ min_xy = x1.new_tensor(0)
+ max_xy = torch.cat([max_shape, max_shape],
+ dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ return bboxes
+
+
+def bbox2distance(points, bbox, max_dis=None, eps=0.1):
+    """Encode bounding boxes as distances from the given points.
+
+ Args:
+ points (Tensor): Shape (n, 2), [x, y].
+ bbox (Tensor): Shape (n, 4), "xyxy" format
+ max_dis (float): Upper bound of the distance.
+        eps (float): a small value to ensure target < max_dis, instead of <=.
+
+ Returns:
+        Tensor: Encoded distances (left, top, right, bottom).
+ """
+ left = points[:, 0] - bbox[:, 0]
+ top = points[:, 1] - bbox[:, 1]
+ right = bbox[:, 2] - points[:, 0]
+ bottom = bbox[:, 3] - points[:, 1]
+ if max_dis is not None:
+ left = left.clamp(min=0, max=max_dis - eps)
+ top = top.clamp(min=0, max=max_dis - eps)
+ right = right.clamp(min=0, max=max_dis - eps)
+ bottom = bottom.clamp(min=0, max=max_dis - eps)
+ return torch.stack([left, top, right, bottom], -1)
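+
+# Sanity check (illustrative): when ``max_dis`` is None (no clamping),
+# bbox2distance and distance2bbox are exact inverses, i.e.
+# distance2bbox(points, bbox2distance(points, bbox)) reproduces ``bbox``.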
+
+
+def bbox_rescale(bboxes, scale_factor=1.0):
+ """Rescale bounding box w.r.t. scale_factor.
+
+ Args:
+ bboxes (Tensor): Shape (n, 4) for bboxes or (n, 5) for rois
+ scale_factor (float): rescale factor
+
+ Returns:
+ Tensor: Rescaled bboxes.
+ """
+ if bboxes.size(1) == 5:
+ bboxes_ = bboxes[:, 1:]
+ inds_ = bboxes[:, 0]
+ else:
+ bboxes_ = bboxes
+ cx = (bboxes_[:, 0] + bboxes_[:, 2]) * 0.5
+ cy = (bboxes_[:, 1] + bboxes_[:, 3]) * 0.5
+ w = bboxes_[:, 2] - bboxes_[:, 0]
+ h = bboxes_[:, 3] - bboxes_[:, 1]
+ w = w * scale_factor
+ h = h * scale_factor
+ x1 = cx - 0.5 * w
+ x2 = cx + 0.5 * w
+ y1 = cy - 0.5 * h
+ y2 = cy + 0.5 * h
+ if bboxes.size(1) == 5:
+ rescaled_bboxes = torch.stack([inds_, x1, y1, x2, y2], dim=-1)
+ else:
+ rescaled_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ return rescaled_bboxes
+
+
+def bbox_cxcywh_to_xyxy(bbox):
+ """Convert bbox coordinates from (cx, cy, w, h) to (x1, y1, x2, y2).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ cx, cy, w, h = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(cx - 0.5 * w), (cy - 0.5 * h), (cx + 0.5 * w), (cy + 0.5 * h)]
+ return torch.cat(bbox_new, dim=-1)
+
+
+def bbox_xyxy_to_cxcywh(bbox):
+ """Convert bbox coordinates from (x1, y1, x2, y2) to (cx, cy, w, h).
+
+ Args:
+ bbox (Tensor): Shape (n, 4) for bboxes.
+
+ Returns:
+ Tensor: Converted bboxes.
+ """
+ x1, y1, x2, y2 = bbox.split((1, 1, 1, 1), dim=-1)
+ bbox_new = [(x1 + x2) / 2, (y1 + y2) / 2, (x2 - x1), (y2 - y1)]
+ return torch.cat(bbox_new, dim=-1)
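+
+
+# Round-trip check (hypothetical values): the two converters above are exact
+# inverses, e.g. the (cx, cy, w, h) box [10., 20., 4., 6.] maps to the
+# (x1, y1, x2, y2) box [8., 17., 12., 23.] and back again.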
diff --git a/mmdet/core/evaluation/__init__.py b/mmdet/core/evaluation/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11ef15b9db95166b4427ad4d08debbd0630a741
--- /dev/null
+++ b/mmdet/core/evaluation/__init__.py
@@ -0,0 +1,15 @@
+from .class_names import (cityscapes_classes, coco_classes, dataset_aliases,
+ get_classes, imagenet_det_classes,
+ imagenet_vid_classes, voc_classes)
+from .eval_hooks import DistEvalHook, EvalHook
+from .mean_ap import average_precision, eval_map, print_map_summary
+from .recall import (eval_recalls, plot_iou_recall, plot_num_recall,
+ print_recall_summary)
+
+__all__ = [
+ 'voc_classes', 'imagenet_det_classes', 'imagenet_vid_classes',
+ 'coco_classes', 'cityscapes_classes', 'dataset_aliases', 'get_classes',
+ 'DistEvalHook', 'EvalHook', 'average_precision', 'eval_map',
+ 'print_map_summary', 'eval_recalls', 'print_recall_summary',
+ 'plot_num_recall', 'plot_iou_recall'
+]
diff --git a/mmdet/core/evaluation/bbox_overlaps.py b/mmdet/core/evaluation/bbox_overlaps.py
new file mode 100644
index 0000000000000000000000000000000000000000..93559ea0f25369d552a5365312fa32b9ffec9226
--- /dev/null
+++ b/mmdet/core/evaluation/bbox_overlaps.py
@@ -0,0 +1,48 @@
+import numpy as np
+
+
+def bbox_overlaps(bboxes1, bboxes2, mode='iou', eps=1e-6):
+ """Calculate the ious between each bbox of bboxes1 and bboxes2.
+
+ Args:
+ bboxes1(ndarray): shape (n, 4)
+ bboxes2(ndarray): shape (k, 4)
+ mode(str): iou (intersection over union) or iof (intersection
+ over foreground)
+
+ Returns:
+ ious(ndarray): shape (n, k)
+ """
+
+ assert mode in ['iou', 'iof']
+
+ bboxes1 = bboxes1.astype(np.float32)
+ bboxes2 = bboxes2.astype(np.float32)
+ rows = bboxes1.shape[0]
+ cols = bboxes2.shape[0]
+ ious = np.zeros((rows, cols), dtype=np.float32)
+ if rows * cols == 0:
+ return ious
+ exchange = False
+ if bboxes1.shape[0] > bboxes2.shape[0]:
+ bboxes1, bboxes2 = bboxes2, bboxes1
+ ious = np.zeros((cols, rows), dtype=np.float32)
+ exchange = True
+ area1 = (bboxes1[:, 2] - bboxes1[:, 0]) * (bboxes1[:, 3] - bboxes1[:, 1])
+ area2 = (bboxes2[:, 2] - bboxes2[:, 0]) * (bboxes2[:, 3] - bboxes2[:, 1])
+ for i in range(bboxes1.shape[0]):
+ x_start = np.maximum(bboxes1[i, 0], bboxes2[:, 0])
+ y_start = np.maximum(bboxes1[i, 1], bboxes2[:, 1])
+ x_end = np.minimum(bboxes1[i, 2], bboxes2[:, 2])
+ y_end = np.minimum(bboxes1[i, 3], bboxes2[:, 3])
+ overlap = np.maximum(x_end - x_start, 0) * np.maximum(
+ y_end - y_start, 0)
+ if mode == 'iou':
+ union = area1[i] + area2 - overlap
+ else:
+ union = area1[i] if not exchange else area2
+ union = np.maximum(union, eps)
+ ious[i, :] = overlap / union
+ if exchange:
+ ious = ious.T
+ return ious
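+
+
+# Worked example (hypothetical boxes): for bboxes1 = [[0, 0, 10, 10]] and
+# bboxes2 = [[5, 5, 15, 15]], the intersection area is 5 * 5 = 25 and the
+# union is 100 + 100 - 25 = 175, so bbox_overlaps returns roughly [[0.143]].
+# Note that this implementation does not add +1 to box widths and heights.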
diff --git a/mmdet/core/evaluation/class_names.py b/mmdet/core/evaluation/class_names.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2487c2ee2d010c40db0e1c2b51c91b194e84dc7
--- /dev/null
+++ b/mmdet/core/evaluation/class_names.py
@@ -0,0 +1,116 @@
+import mmcv
+
+
+def wider_face_classes():
+ return ['face']
+
+
+def voc_classes():
+ return [
+ 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat',
+ 'chair', 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person',
+ 'pottedplant', 'sheep', 'sofa', 'train', 'tvmonitor'
+ ]
+
+
+def imagenet_det_classes():
+ return [
+ 'accordion', 'airplane', 'ant', 'antelope', 'apple', 'armadillo',
+ 'artichoke', 'axe', 'baby_bed', 'backpack', 'bagel', 'balance_beam',
+ 'banana', 'band_aid', 'banjo', 'baseball', 'basketball', 'bathing_cap',
+ 'beaker', 'bear', 'bee', 'bell_pepper', 'bench', 'bicycle', 'binder',
+ 'bird', 'bookshelf', 'bow_tie', 'bow', 'bowl', 'brassiere', 'burrito',
+ 'bus', 'butterfly', 'camel', 'can_opener', 'car', 'cart', 'cattle',
+ 'cello', 'centipede', 'chain_saw', 'chair', 'chime', 'cocktail_shaker',
+ 'coffee_maker', 'computer_keyboard', 'computer_mouse', 'corkscrew',
+ 'cream', 'croquet_ball', 'crutch', 'cucumber', 'cup_or_mug', 'diaper',
+ 'digital_clock', 'dishwasher', 'dog', 'domestic_cat', 'dragonfly',
+ 'drum', 'dumbbell', 'electric_fan', 'elephant', 'face_powder', 'fig',
+ 'filing_cabinet', 'flower_pot', 'flute', 'fox', 'french_horn', 'frog',
+ 'frying_pan', 'giant_panda', 'goldfish', 'golf_ball', 'golfcart',
+ 'guacamole', 'guitar', 'hair_dryer', 'hair_spray', 'hamburger',
+ 'hammer', 'hamster', 'harmonica', 'harp', 'hat_with_a_wide_brim',
+ 'head_cabbage', 'helmet', 'hippopotamus', 'horizontal_bar', 'horse',
+ 'hotdog', 'iPod', 'isopod', 'jellyfish', 'koala_bear', 'ladle',
+ 'ladybug', 'lamp', 'laptop', 'lemon', 'lion', 'lipstick', 'lizard',
+ 'lobster', 'maillot', 'maraca', 'microphone', 'microwave', 'milk_can',
+ 'miniskirt', 'monkey', 'motorcycle', 'mushroom', 'nail', 'neck_brace',
+ 'oboe', 'orange', 'otter', 'pencil_box', 'pencil_sharpener', 'perfume',
+ 'person', 'piano', 'pineapple', 'ping-pong_ball', 'pitcher', 'pizza',
+ 'plastic_bag', 'plate_rack', 'pomegranate', 'popsicle', 'porcupine',
+ 'power_drill', 'pretzel', 'printer', 'puck', 'punching_bag', 'purse',
+ 'rabbit', 'racket', 'ray', 'red_panda', 'refrigerator',
+ 'remote_control', 'rubber_eraser', 'rugby_ball', 'ruler',
+ 'salt_or_pepper_shaker', 'saxophone', 'scorpion', 'screwdriver',
+ 'seal', 'sheep', 'ski', 'skunk', 'snail', 'snake', 'snowmobile',
+ 'snowplow', 'soap_dispenser', 'soccer_ball', 'sofa', 'spatula',
+ 'squirrel', 'starfish', 'stethoscope', 'stove', 'strainer',
+ 'strawberry', 'stretcher', 'sunglasses', 'swimming_trunks', 'swine',
+ 'syringe', 'table', 'tape_player', 'tennis_ball', 'tick', 'tie',
+ 'tiger', 'toaster', 'traffic_light', 'train', 'trombone', 'trumpet',
+ 'turtle', 'tv_or_monitor', 'unicycle', 'vacuum', 'violin',
+ 'volleyball', 'waffle_iron', 'washer', 'water_bottle', 'watercraft',
+ 'whale', 'wine_bottle', 'zebra'
+ ]
+
+
+def imagenet_vid_classes():
+ return [
+ 'airplane', 'antelope', 'bear', 'bicycle', 'bird', 'bus', 'car',
+ 'cattle', 'dog', 'domestic_cat', 'elephant', 'fox', 'giant_panda',
+ 'hamster', 'horse', 'lion', 'lizard', 'monkey', 'motorcycle', 'rabbit',
+ 'red_panda', 'sheep', 'snake', 'squirrel', 'tiger', 'train', 'turtle',
+ 'watercraft', 'whale', 'zebra'
+ ]
+
+
+def coco_classes():
+ return [
+ 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus', 'train',
+ 'truck', 'boat', 'traffic_light', 'fire_hydrant', 'stop_sign',
+ 'parking_meter', 'bench', 'bird', 'cat', 'dog', 'horse', 'sheep',
+ 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'backpack', 'umbrella',
+ 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
+ 'sports_ball', 'kite', 'baseball_bat', 'baseball_glove', 'skateboard',
+ 'surfboard', 'tennis_racket', 'bottle', 'wine_glass', 'cup', 'fork',
+ 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich', 'orange',
+ 'broccoli', 'carrot', 'hot_dog', 'pizza', 'donut', 'cake', 'chair',
+ 'couch', 'potted_plant', 'bed', 'dining_table', 'toilet', 'tv',
+ 'laptop', 'mouse', 'remote', 'keyboard', 'cell_phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock', 'vase',
+ 'scissors', 'teddy_bear', 'hair_drier', 'toothbrush'
+ ]
+
+
+def cityscapes_classes():
+ return [
+ 'person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle'
+ ]
+
+
+dataset_aliases = {
+ 'voc': ['voc', 'pascal_voc', 'voc07', 'voc12'],
+ 'imagenet_det': ['det', 'imagenet_det', 'ilsvrc_det'],
+ 'imagenet_vid': ['vid', 'imagenet_vid', 'ilsvrc_vid'],
+ 'coco': ['coco', 'mscoco', 'ms_coco'],
+ 'wider_face': ['WIDERFaceDataset', 'wider_face', 'WIDERFace'],
+ 'cityscapes': ['cityscapes']
+}
+
+
+def get_classes(dataset):
+ """Get class names of a dataset."""
+ alias2name = {}
+ for name, aliases in dataset_aliases.items():
+ for alias in aliases:
+ alias2name[alias] = name
+
+ if mmcv.is_str(dataset):
+ if dataset in alias2name:
+ labels = eval(alias2name[dataset] + '_classes()')
+ else:
+ raise ValueError(f'Unrecognized dataset: {dataset}')
+ else:
+        raise TypeError(f'dataset must be a str, but got {type(dataset)}')
+ return labels
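+
+
+# Usage sketch (illustrative): any alias listed in ``dataset_aliases`` works,
+# e.g. get_classes('voc') and get_classes('pascal_voc') both return the 20
+# PASCAL VOC class names, while an unrecognized string raises ValueError.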
diff --git a/mmdet/core/evaluation/eval_hooks.py b/mmdet/core/evaluation/eval_hooks.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fb932eae1ccb23a2b687a05a6cb9525200de718
--- /dev/null
+++ b/mmdet/core/evaluation/eval_hooks.py
@@ -0,0 +1,303 @@
+import os.path as osp
+import warnings
+from math import inf
+
+import mmcv
+import torch.distributed as dist
+from mmcv.runner import Hook
+from torch.nn.modules.batchnorm import _BatchNorm
+from torch.utils.data import DataLoader
+
+from mmdet.utils import get_root_logger
+
+
+class EvalHook(Hook):
+ """Evaluation hook.
+
+ Notes:
+ If new arguments are added for EvalHook, tools/test.py,
+        tools/analysis_tools/eval_metric.py may be affected.
+
+ Attributes:
+ dataloader (DataLoader): A PyTorch dataloader.
+ start (int, optional): Evaluation starting epoch. It enables evaluation
+ before the training starts if ``start`` <= the resuming epoch.
+ If None, whether to evaluate is merely decided by ``interval``.
+ Default: None.
+ interval (int): Evaluation interval (by epochs). Default: 1.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in best.json.
+ Options are the evaluation metrics to the test dataset. e.g.,
+ ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
+ segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
+ ``auto``, the first key will be used. The interval of
+            ``CheckpointHook`` should divide that of ``EvalHook``.
+            Default: None.
+ rule (str, optional): Comparison rule for best score. If set to None,
+ it will infer a reasonable rule. Keys such as 'mAP' or 'AR' will
+            be inferred by the 'greater' rule. Keys containing 'loss' will be
+            inferred by the 'less' rule. Options are 'greater', 'less'.
+            Default: None.
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+ """
+
+ rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
+ init_value_map = {'greater': -inf, 'less': inf}
+ greater_keys = ['mAP', 'AR']
+ less_keys = ['loss']
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ save_best=None,
+ rule=None,
+ **eval_kwargs):
+ if not isinstance(dataloader, DataLoader):
+ raise TypeError('dataloader must be a pytorch DataLoader, but got'
+ f' {type(dataloader)}')
+ if not interval > 0:
+ raise ValueError(f'interval must be positive, but got {interval}')
+ if start is not None and start < 0:
+ warnings.warn(
+ f'The evaluation start epoch {start} is smaller than 0, '
+ f'use 0 instead', UserWarning)
+ start = 0
+ self.dataloader = dataloader
+ self.interval = interval
+ self.by_epoch = by_epoch
+ self.start = start
+ assert isinstance(save_best, str) or save_best is None
+ self.save_best = save_best
+ self.eval_kwargs = eval_kwargs
+ self.initial_epoch_flag = True
+
+ self.logger = get_root_logger()
+
+ if self.save_best is not None:
+ self._init_rule(rule, self.save_best)
+
+ def _init_rule(self, rule, key_indicator):
+ """Initialize rule, key_indicator, comparison_func, and best score.
+
+ Args:
+ rule (str | None): Comparison rule for best score.
+ key_indicator (str | None): Key indicator to determine the
+ comparison rule.
+ """
+ if rule not in self.rule_map and rule is not None:
+ raise KeyError(f'rule must be greater, less or None, '
+ f'but got {rule}.')
+
+ if rule is None:
+ if key_indicator != 'auto':
+ if any(key in key_indicator for key in self.greater_keys):
+ rule = 'greater'
+ elif any(key in key_indicator for key in self.less_keys):
+ rule = 'less'
+ else:
+ raise ValueError(f'Cannot infer the rule for key '
+ f'{key_indicator}, thus a specific rule '
+ f'must be specified.')
+ self.rule = rule
+ self.key_indicator = key_indicator
+ if self.rule is not None:
+ self.compare_func = self.rule_map[self.rule]
+
+ def before_run(self, runner):
+ if self.save_best is not None:
+ if runner.meta is None:
+                warnings.warn('runner.meta is None. Creating an empty one.')
+ runner.meta = dict()
+ runner.meta.setdefault('hook_msgs', dict())
+
+ def before_train_epoch(self, runner):
+ """Evaluate the model only at the start of training."""
+ if not self.initial_epoch_flag:
+ return
+ if self.start is not None and runner.epoch >= self.start:
+ self.after_train_epoch(runner)
+ self.initial_epoch_flag = False
+
+ def evaluation_flag(self, runner):
+        """Judge whether to perform evaluation after this epoch.
+
+ Returns:
+ bool: The flag indicating whether to perform evaluation.
+ """
+ if self.start is None:
+ if not self.every_n_epochs(runner, self.interval):
+ # No evaluation during the interval epochs.
+ return False
+ elif (runner.epoch + 1) < self.start:
+ # No evaluation if start is larger than the current epoch.
+ return False
+ else:
+ # Evaluation only at epochs 3, 5, 7... if start==3 and interval==2
+ if (runner.epoch + 1 - self.start) % self.interval:
+ return False
+ return True
+
+ def after_train_epoch(self, runner):
+ if not self.by_epoch or not self.evaluation_flag(runner):
+ return
+ from mmdet.apis import single_gpu_test
+ results = single_gpu_test(runner.model, self.dataloader, show=False)
+ key_score = self.evaluate(runner, results)
+ if self.save_best:
+ self.save_best_checkpoint(runner, key_score)
+
+ def after_train_iter(self, runner):
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
+ return
+ from mmdet.apis import single_gpu_test
+ results = single_gpu_test(runner.model, self.dataloader, show=False)
+ key_score = self.evaluate(runner, results)
+ if self.save_best:
+ self.save_best_checkpoint(runner, key_score)
+
+ def save_best_checkpoint(self, runner, key_score):
+ best_score = runner.meta['hook_msgs'].get(
+ 'best_score', self.init_value_map[self.rule])
+ if self.compare_func(key_score, best_score):
+ best_score = key_score
+ runner.meta['hook_msgs']['best_score'] = best_score
+ last_ckpt = runner.meta['hook_msgs']['last_ckpt']
+ runner.meta['hook_msgs']['best_ckpt'] = last_ckpt
+ mmcv.symlink(
+ last_ckpt,
+ osp.join(runner.work_dir, f'best_{self.key_indicator}.pth'))
+ time_stamp = runner.epoch + 1 if self.by_epoch else runner.iter + 1
+            self.logger.info(f'Now best checkpoint is epoch_{time_stamp}.pth. '
+ f'Best {self.key_indicator} is {best_score:0.4f}')
+
+ def evaluate(self, runner, results):
+ eval_res = self.dataloader.dataset.evaluate(
+ results, logger=runner.logger, **self.eval_kwargs)
+ for name, val in eval_res.items():
+ runner.log_buffer.output[name] = val
+ runner.log_buffer.ready = True
+ if self.save_best is not None:
+ if self.key_indicator == 'auto':
+ # infer from eval_results
+ self._init_rule(self.rule, list(eval_res.keys())[0])
+ return eval_res[self.key_indicator]
+ else:
+ return None
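+
+    # Usage note (illustrative, assumes an mmdet-style training config): this
+    # hook is usually built from the ``evaluation`` dict of the config, e.g.
+    #     evaluation = dict(interval=1, metric='bbox', save_best='bbox_mAP')
+    # which evaluates every epoch, infers the 'greater' rule for 'bbox_mAP',
+    # and keeps a ``best_bbox_mAP.pth`` symlink to the best checkpoint so far.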
+
+
+class DistEvalHook(EvalHook):
+ """Distributed evaluation hook.
+
+ Notes:
+        If new arguments are added, tools/test.py may be affected.
+
+ Attributes:
+ dataloader (DataLoader): A PyTorch dataloader.
+ start (int, optional): Evaluation starting epoch. It enables evaluation
+ before the training starts if ``start`` <= the resuming epoch.
+ If None, whether to evaluate is merely decided by ``interval``.
+ Default: None.
+ interval (int): Evaluation interval (by epochs). Default: 1.
+ tmpdir (str | None): Temporary directory to save the results of all
+ processes. Default: None.
+ gpu_collect (bool): Whether to use gpu or cpu to collect results.
+ Default: False.
+ save_best (str, optional): If a metric is specified, it would measure
+ the best checkpoint during evaluation. The information about best
+            checkpoint would be saved in best.json.
+ Options are the evaluation metrics to the test dataset. e.g.,
+ ``bbox_mAP``, ``segm_mAP`` for bbox detection and instance
+ segmentation. ``AR@100`` for proposal recall. If ``save_best`` is
+ ``auto``, the first key will be used. The interval of
+            ``CheckpointHook`` should divide that of ``EvalHook``.
+            Default: None.
+ rule (str | None): Comparison rule for best score. If set to None,
+            it will infer a reasonable rule. Default: None.
+ broadcast_bn_buffer (bool): Whether to broadcast the
+            buffers (running_mean and running_var) of rank 0 to other ranks
+ before evaluation. Default: True.
+ **eval_kwargs: Evaluation arguments fed into the evaluate function of
+ the dataset.
+ """
+
+ def __init__(self,
+ dataloader,
+ start=None,
+ interval=1,
+ by_epoch=True,
+ tmpdir=None,
+ gpu_collect=False,
+ save_best=None,
+ rule=None,
+ broadcast_bn_buffer=True,
+ **eval_kwargs):
+ super().__init__(
+ dataloader,
+ start=start,
+ interval=interval,
+ by_epoch=by_epoch,
+ save_best=save_best,
+ rule=rule,
+ **eval_kwargs)
+ self.broadcast_bn_buffer = broadcast_bn_buffer
+ self.tmpdir = tmpdir
+ self.gpu_collect = gpu_collect
+
+ def _broadcast_bn_buffer(self, runner):
+ # Synchronization of BatchNorm's buffer (running_mean
+ # and running_var) is not supported in the DDP of pytorch,
+        # which may cause inconsistent performance of models in
+ # different ranks, so we broadcast BatchNorm's buffers
+ # of rank 0 to other ranks to avoid this.
+ if self.broadcast_bn_buffer:
+ model = runner.model
+ for name, module in model.named_modules():
+ if isinstance(module,
+ _BatchNorm) and module.track_running_stats:
+ dist.broadcast(module.running_var, 0)
+ dist.broadcast(module.running_mean, 0)
+
+ def after_train_epoch(self, runner):
+ if not self.by_epoch or not self.evaluation_flag(runner):
+ return
+
+ if self.broadcast_bn_buffer:
+ self._broadcast_bn_buffer(runner)
+
+ from mmdet.apis import multi_gpu_test
+ tmpdir = self.tmpdir
+ if tmpdir is None:
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=tmpdir,
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ key_score = self.evaluate(runner, results)
+ if self.save_best:
+ self.save_best_checkpoint(runner, key_score)
+
+ def after_train_iter(self, runner):
+ if self.by_epoch or not self.every_n_iters(runner, self.interval):
+ return
+
+ if self.broadcast_bn_buffer:
+ self._broadcast_bn_buffer(runner)
+
+ from mmdet.apis import multi_gpu_test
+ tmpdir = self.tmpdir
+ if tmpdir is None:
+ tmpdir = osp.join(runner.work_dir, '.eval_hook')
+ results = multi_gpu_test(
+ runner.model,
+ self.dataloader,
+ tmpdir=tmpdir,
+ gpu_collect=self.gpu_collect)
+ if runner.rank == 0:
+ print('\n')
+ key_score = self.evaluate(runner, results)
+ if self.save_best:
+ self.save_best_checkpoint(runner, key_score)
diff --git a/mmdet/core/evaluation/mean_ap.py b/mmdet/core/evaluation/mean_ap.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d653a35497f6a0135c4374a09eb7c11399e3244
--- /dev/null
+++ b/mmdet/core/evaluation/mean_ap.py
@@ -0,0 +1,469 @@
+from multiprocessing import Pool
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from .bbox_overlaps import bbox_overlaps
+from .class_names import get_classes
+
+
+def average_precision(recalls, precisions, mode='area'):
+ """Calculate average precision (for single or multiple scales).
+
+ Args:
+ recalls (ndarray): shape (num_scales, num_dets) or (num_dets, )
+ precisions (ndarray): shape (num_scales, num_dets) or (num_dets, )
+ mode (str): 'area' or '11points', 'area' means calculating the area
+ under precision-recall curve, '11points' means calculating
+ the average precision of recalls at [0, 0.1, ..., 1]
+
+ Returns:
+ float or ndarray: calculated average precision
+ """
+ no_scale = False
+ if recalls.ndim == 1:
+ no_scale = True
+ recalls = recalls[np.newaxis, :]
+ precisions = precisions[np.newaxis, :]
+ assert recalls.shape == precisions.shape and recalls.ndim == 2
+ num_scales = recalls.shape[0]
+ ap = np.zeros(num_scales, dtype=np.float32)
+ if mode == 'area':
+ zeros = np.zeros((num_scales, 1), dtype=recalls.dtype)
+ ones = np.ones((num_scales, 1), dtype=recalls.dtype)
+ mrec = np.hstack((zeros, recalls, ones))
+ mpre = np.hstack((zeros, precisions, zeros))
+ for i in range(mpre.shape[1] - 1, 0, -1):
+ mpre[:, i - 1] = np.maximum(mpre[:, i - 1], mpre[:, i])
+ for i in range(num_scales):
+ ind = np.where(mrec[i, 1:] != mrec[i, :-1])[0]
+ ap[i] = np.sum(
+ (mrec[i, ind + 1] - mrec[i, ind]) * mpre[i, ind + 1])
+ elif mode == '11points':
+ for i in range(num_scales):
+ for thr in np.arange(0, 1 + 1e-3, 0.1):
+ precs = precisions[i, recalls[i, :] >= thr]
+ prec = precs.max() if precs.size > 0 else 0
+ ap[i] += prec
+ ap /= 11
+ else:
+ raise ValueError(
+ 'Unrecognized mode, only "area" and "11points" are supported')
+ if no_scale:
+ ap = ap[0]
+ return ap
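+
+
+# Hand check (hypothetical values): for recalls = [0.5, 1.0] and
+# precisions = [1.0, 0.5] in 'area' mode, the interpolated precision envelope
+# is 1.0 on the recall interval [0, 0.5] and 0.5 on [0.5, 1.0], so the
+# returned AP is 0.5 * 1.0 + 0.5 * 0.5 = 0.75.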
+
+
+def tpfp_imagenet(det_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ default_iou_thr=0.5,
+ area_ranges=None):
+ """Check if detected bboxes are true positive or false positive.
+
+ Args:
+        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+ gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+ gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+ of shape (k, 4). Default: None
+ default_iou_thr (float): IoU threshold to be considered as matched for
+ medium and large bboxes (small ones have special rules).
+ Default: 0.5.
+ area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+ in the format [(min1, max1), (min2, max2), ...]. Default: None.
+
+ Returns:
+ tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+ each array is (num_scales, m).
+ """
+ # an indicator of ignored gts
+ gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0], dtype=bool),
+         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+ # stack gt_bboxes and gt_bboxes_ignore for convenience
+ gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+ num_dets = det_bboxes.shape[0]
+ num_gts = gt_bboxes.shape[0]
+ if area_ranges is None:
+ area_ranges = [(None, None)]
+ num_scales = len(area_ranges)
+    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp
+    # of a certain scale.
+ tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ if gt_bboxes.shape[0] == 0:
+ if area_ranges == [(None, None)]:
+ fp[...] = 1
+ else:
+ det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0]) * (
+ det_bboxes[:, 3] - det_bboxes[:, 1])
+ for i, (min_area, max_area) in enumerate(area_ranges):
+ fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+ return tp, fp
+ ious = bbox_overlaps(det_bboxes, gt_bboxes - 1)
+ gt_w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
+ gt_h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
+ iou_thrs = np.minimum((gt_w * gt_h) / ((gt_w + 10.0) * (gt_h + 10.0)),
+ default_iou_thr)
+ # sort all detections by scores in descending order
+ sort_inds = np.argsort(-det_bboxes[:, -1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ gt_covered = np.zeros(num_gts, dtype=bool)
+ # if no area range is specified, gt_area_ignore is all False
+ if min_area is None:
+ gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+ else:
+ gt_areas = gt_w * gt_h
+ gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
+ for i in sort_inds:
+ max_iou = -1
+ matched_gt = -1
+ # find best overlapped available gt
+ for j in range(num_gts):
+ # different from PASCAL VOC: allow finding other gts if the
+ # best overlapped ones are already matched by other det bboxes
+ if gt_covered[j]:
+ continue
+ elif ious[i, j] >= iou_thrs[j] and ious[i, j] > max_iou:
+ max_iou = ious[i, j]
+ matched_gt = j
+ # there are 4 cases for a det bbox:
+ # 1. it matches a gt, tp = 1, fp = 0
+ # 2. it matches an ignored gt, tp = 0, fp = 0
+ # 3. it matches no gt and within area range, tp = 0, fp = 1
+ # 4. it matches no gt but is beyond area range, tp = 0, fp = 0
+ if matched_gt >= 0:
+ gt_covered[matched_gt] = 1
+ if not (gt_ignore_inds[matched_gt]
+ or gt_area_ignore[matched_gt]):
+ tp[k, i] = 1
+ elif min_area is None:
+ fp[k, i] = 1
+ else:
+ bbox = det_bboxes[i, :4]
+ area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+ if area >= min_area and area < max_area:
+ fp[k, i] = 1
+ return tp, fp
+
+
+def tpfp_default(det_bboxes,
+ gt_bboxes,
+ gt_bboxes_ignore=None,
+ iou_thr=0.5,
+ area_ranges=None):
+ """Check if detected bboxes are true positive or false positive.
+
+ Args:
+        det_bboxes (ndarray): Detected bboxes of this image, of shape (m, 5).
+ gt_bboxes (ndarray): GT bboxes of this image, of shape (n, 4).
+ gt_bboxes_ignore (ndarray): Ignored gt bboxes of this image,
+ of shape (k, 4). Default: None
+ iou_thr (float): IoU threshold to be considered as matched.
+ Default: 0.5.
+ area_ranges (list[tuple] | None): Range of bbox areas to be evaluated,
+ in the format [(min1, max1), (min2, max2), ...]. Default: None.
+
+ Returns:
+ tuple[np.ndarray]: (tp, fp) whose elements are 0 and 1. The shape of
+ each array is (num_scales, m).
+ """
+ # an indicator of ignored gts
+ gt_ignore_inds = np.concatenate(
+        (np.zeros(gt_bboxes.shape[0], dtype=bool),
+         np.ones(gt_bboxes_ignore.shape[0], dtype=bool)))
+ # stack gt_bboxes and gt_bboxes_ignore for convenience
+ gt_bboxes = np.vstack((gt_bboxes, gt_bboxes_ignore))
+
+ num_dets = det_bboxes.shape[0]
+ num_gts = gt_bboxes.shape[0]
+ if area_ranges is None:
+ area_ranges = [(None, None)]
+ num_scales = len(area_ranges)
+    # tp and fp are of shape (num_scales, num_dets), each row is tp or fp of
+    # a certain scale
+ tp = np.zeros((num_scales, num_dets), dtype=np.float32)
+ fp = np.zeros((num_scales, num_dets), dtype=np.float32)
+
+ # if there is no gt bboxes in this image, then all det bboxes
+ # within area range are false positives
+ if gt_bboxes.shape[0] == 0:
+ if area_ranges == [(None, None)]:
+ fp[...] = 1
+ else:
+ det_areas = (det_bboxes[:, 2] - det_bboxes[:, 0]) * (
+ det_bboxes[:, 3] - det_bboxes[:, 1])
+ for i, (min_area, max_area) in enumerate(area_ranges):
+ fp[i, (det_areas >= min_area) & (det_areas < max_area)] = 1
+ return tp, fp
+
+ ious = bbox_overlaps(det_bboxes, gt_bboxes)
+ # for each det, the max iou with all gts
+ ious_max = ious.max(axis=1)
+ # for each det, which gt overlaps most with it
+ ious_argmax = ious.argmax(axis=1)
+ # sort all dets in descending order by scores
+ sort_inds = np.argsort(-det_bboxes[:, -1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ gt_covered = np.zeros(num_gts, dtype=bool)
+ # if no area range is specified, gt_area_ignore is all False
+ if min_area is None:
+ gt_area_ignore = np.zeros_like(gt_ignore_inds, dtype=bool)
+ else:
+ gt_areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+ gt_bboxes[:, 3] - gt_bboxes[:, 1])
+ gt_area_ignore = (gt_areas < min_area) | (gt_areas >= max_area)
+ for i in sort_inds:
+ if ious_max[i] >= iou_thr:
+ matched_gt = ious_argmax[i]
+ if not (gt_ignore_inds[matched_gt]
+ or gt_area_ignore[matched_gt]):
+ if not gt_covered[matched_gt]:
+ gt_covered[matched_gt] = True
+ tp[k, i] = 1
+ else:
+ fp[k, i] = 1
+ # otherwise ignore this detected bbox, tp = 0, fp = 0
+ elif min_area is None:
+ fp[k, i] = 1
+ else:
+ bbox = det_bboxes[i, :4]
+ area = (bbox[2] - bbox[0]) * (bbox[3] - bbox[1])
+ if area >= min_area and area < max_area:
+ fp[k, i] = 1
+ return tp, fp
+
+
+def get_cls_results(det_results, annotations, class_id):
+ """Get det results and gt information of a certain class.
+
+ Args:
+ det_results (list[list]): Same as `eval_map()`.
+ annotations (list[dict]): Same as `eval_map()`.
+ class_id (int): ID of a specific class.
+
+ Returns:
+ tuple[list[np.ndarray]]: detected bboxes, gt bboxes, ignored gt bboxes
+ """
+ cls_dets = [img_res[class_id] for img_res in det_results]
+ cls_gts = []
+ cls_gts_ignore = []
+ for ann in annotations:
+ gt_inds = ann['labels'] == class_id
+ cls_gts.append(ann['bboxes'][gt_inds, :])
+
+ if ann.get('labels_ignore', None) is not None:
+ ignore_inds = ann['labels_ignore'] == class_id
+ cls_gts_ignore.append(ann['bboxes_ignore'][ignore_inds, :])
+ else:
+ cls_gts_ignore.append(np.empty((0, 4), dtype=np.float32))
+
+ return cls_dets, cls_gts, cls_gts_ignore
+
+
+def eval_map(det_results,
+ annotations,
+ scale_ranges=None,
+ iou_thr=0.5,
+ dataset=None,
+ logger=None,
+ tpfp_fn=None,
+ nproc=4):
+ """Evaluate mAP of a dataset.
+
+ Args:
+ det_results (list[list]): [[cls1_det, cls2_det, ...], ...].
+ The outer list indicates images, and the inner list indicates
+ per-class detected bboxes.
+ annotations (list[dict]): Ground truth annotations where each item of
+ the list indicates an image. Keys of annotations are:
+
+ - `bboxes`: numpy array of shape (n, 4)
+ - `labels`: numpy array of shape (n, )
+ - `bboxes_ignore` (optional): numpy array of shape (k, 4)
+ - `labels_ignore` (optional): numpy array of shape (k, )
+ scale_ranges (list[tuple] | None): Range of scales to be evaluated,
+ in the format [(min1, max1), (min2, max2), ...]. A range of
+ (32, 64) means the area range between (32**2, 64**2).
+ Default: None.
+ iou_thr (float): IoU threshold to be considered as matched.
+ Default: 0.5.
+ dataset (list[str] | str | None): Dataset name or dataset classes,
+            there are minor differences in metrics for different datasets, e.g.
+ "voc07", "imagenet_det", etc. Default: None.
+ logger (logging.Logger | str | None): The way to print the mAP
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ tpfp_fn (callable | None): The function used to determine true/
+ false positives. If None, :func:`tpfp_default` is used as default
+ unless dataset is 'det' or 'vid' (:func:`tpfp_imagenet` in this
+ case). If it is given as a function, then this function is used
+ to evaluate tp & fp. Default None.
+ nproc (int): Processes used for computing TP and FP.
+ Default: 4.
+
+ Returns:
+ tuple: (mAP, [dict, dict, ...])
+ """
+ assert len(det_results) == len(annotations)
+
+ num_imgs = len(det_results)
+ num_scales = len(scale_ranges) if scale_ranges is not None else 1
+ num_classes = len(det_results[0]) # positive class num
+ area_ranges = ([(rg[0]**2, rg[1]**2) for rg in scale_ranges]
+ if scale_ranges is not None else None)
+
+ pool = Pool(nproc)
+ eval_results = []
+ for i in range(num_classes):
+ # get gt and det bboxes of this class
+ cls_dets, cls_gts, cls_gts_ignore = get_cls_results(
+ det_results, annotations, i)
+ # choose proper function according to datasets to compute tp and fp
+ if tpfp_fn is None:
+ if dataset in ['det', 'vid']:
+ tpfp_fn = tpfp_imagenet
+ else:
+ tpfp_fn = tpfp_default
+ if not callable(tpfp_fn):
+ raise ValueError(
+ f'tpfp_fn has to be a function or None, but got {tpfp_fn}')
+
+ # compute tp and fp for each image with multiple processes
+ tpfp = pool.starmap(
+ tpfp_fn,
+ zip(cls_dets, cls_gts, cls_gts_ignore,
+ [iou_thr for _ in range(num_imgs)],
+ [area_ranges for _ in range(num_imgs)]))
+ tp, fp = tuple(zip(*tpfp))
+ # calculate gt number of each scale
+ # ignored gts or gts beyond the specific scale are not counted
+ num_gts = np.zeros(num_scales, dtype=int)
+ for j, bbox in enumerate(cls_gts):
+ if area_ranges is None:
+ num_gts[0] += bbox.shape[0]
+ else:
+ gt_areas = (bbox[:, 2] - bbox[:, 0]) * (
+ bbox[:, 3] - bbox[:, 1])
+ for k, (min_area, max_area) in enumerate(area_ranges):
+ num_gts[k] += np.sum((gt_areas >= min_area)
+ & (gt_areas < max_area))
+ # sort all det bboxes by score, also sort tp and fp
+ cls_dets = np.vstack(cls_dets)
+ num_dets = cls_dets.shape[0]
+ sort_inds = np.argsort(-cls_dets[:, -1])
+ tp = np.hstack(tp)[:, sort_inds]
+ fp = np.hstack(fp)[:, sort_inds]
+ # calculate recall and precision with tp and fp
+ tp = np.cumsum(tp, axis=1)
+ fp = np.cumsum(fp, axis=1)
+ eps = np.finfo(np.float32).eps
+ recalls = tp / np.maximum(num_gts[:, np.newaxis], eps)
+ precisions = tp / np.maximum((tp + fp), eps)
+ # calculate AP
+ if scale_ranges is None:
+ recalls = recalls[0, :]
+ precisions = precisions[0, :]
+ num_gts = num_gts.item()
+ mode = 'area' if dataset != 'voc07' else '11points'
+ ap = average_precision(recalls, precisions, mode)
+ eval_results.append({
+ 'num_gts': num_gts,
+ 'num_dets': num_dets,
+ 'recall': recalls,
+ 'precision': precisions,
+ 'ap': ap
+ })
+ pool.close()
+ if scale_ranges is not None:
+ # shape (num_classes, num_scales)
+ all_ap = np.vstack([cls_result['ap'] for cls_result in eval_results])
+ all_num_gts = np.vstack(
+ [cls_result['num_gts'] for cls_result in eval_results])
+ mean_ap = []
+ for i in range(num_scales):
+ if np.any(all_num_gts[:, i] > 0):
+ mean_ap.append(all_ap[all_num_gts[:, i] > 0, i].mean())
+ else:
+ mean_ap.append(0.0)
+ else:
+ aps = []
+ for cls_result in eval_results:
+ if cls_result['num_gts'] > 0:
+ aps.append(cls_result['ap'])
+ mean_ap = np.array(aps).mean().item() if aps else 0.0
+
+ print_map_summary(
+ mean_ap, eval_results, dataset, area_ranges, logger=logger)
+
+ return mean_ap, eval_results
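+
+
+# Minimal usage sketch (hypothetical data): ``det_results`` holds one entry per
+# image, each a list of per-class (k, 5) arrays, and ``annotations`` holds one
+# dict per image with 'bboxes' and 'labels'; then
+#     mean_ap, per_class = eval_map(det_results, annotations, iou_thr=0.5)
+# returns a scalar mAP and a list of per-class result dicts.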
+
+
+def print_map_summary(mean_ap,
+ results,
+ dataset=None,
+ scale_ranges=None,
+ logger=None):
+ """Print mAP and results of each class.
+
+ A table will be printed to show the gts/dets/recall/AP of each class and
+ the mAP.
+
+ Args:
+ mean_ap (float): Calculated from `eval_map()`.
+ results (list[dict]): Calculated from `eval_map()`.
+ dataset (list[str] | str | None): Dataset name or dataset classes.
+ scale_ranges (list[tuple] | None): Range of scales to be evaluated.
+ logger (logging.Logger | str | None): The way to print the mAP
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ """
+
+ if logger == 'silent':
+ return
+
+ if isinstance(results[0]['ap'], np.ndarray):
+ num_scales = len(results[0]['ap'])
+ else:
+ num_scales = 1
+
+ if scale_ranges is not None:
+ assert len(scale_ranges) == num_scales
+
+ num_classes = len(results)
+
+ recalls = np.zeros((num_scales, num_classes), dtype=np.float32)
+ aps = np.zeros((num_scales, num_classes), dtype=np.float32)
+ num_gts = np.zeros((num_scales, num_classes), dtype=int)
+ for i, cls_result in enumerate(results):
+ if cls_result['recall'].size > 0:
+ recalls[:, i] = np.array(cls_result['recall'], ndmin=2)[:, -1]
+ aps[:, i] = cls_result['ap']
+ num_gts[:, i] = cls_result['num_gts']
+
+ if dataset is None:
+ label_names = [str(i) for i in range(num_classes)]
+ elif mmcv.is_str(dataset):
+ label_names = get_classes(dataset)
+ else:
+ label_names = dataset
+
+ if not isinstance(mean_ap, list):
+ mean_ap = [mean_ap]
+
+ header = ['class', 'gts', 'dets', 'recall', 'ap']
+ for i in range(num_scales):
+ if scale_ranges is not None:
+ print_log(f'Scale range {scale_ranges[i]}', logger=logger)
+ table_data = [header]
+ for j in range(num_classes):
+ row_data = [
+ label_names[j], num_gts[i, j], results[j]['num_dets'],
+ f'{recalls[i, j]:.3f}', f'{aps[i, j]:.3f}'
+ ]
+ table_data.append(row_data)
+ table_data.append(['mAP', '', '', '', f'{mean_ap[i]:.3f}'])
+ table = AsciiTable(table_data)
+ table.inner_footing_row_border = True
+ print_log('\n' + table.table, logger=logger)
diff --git a/mmdet/core/evaluation/recall.py b/mmdet/core/evaluation/recall.py
new file mode 100644
index 0000000000000000000000000000000000000000..23ec744f552db1a4a76bfa63b7cc8b357deb3140
--- /dev/null
+++ b/mmdet/core/evaluation/recall.py
@@ -0,0 +1,189 @@
+from collections.abc import Sequence
+
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from .bbox_overlaps import bbox_overlaps
+
+
+def _recalls(all_ious, proposal_nums, thrs):
+
+ img_num = all_ious.shape[0]
+ total_gt_num = sum([ious.shape[0] for ious in all_ious])
+
+ _ious = np.zeros((proposal_nums.size, total_gt_num), dtype=np.float32)
+ for k, proposal_num in enumerate(proposal_nums):
+ tmp_ious = np.zeros(0)
+ for i in range(img_num):
+ ious = all_ious[i][:, :proposal_num].copy()
+ gt_ious = np.zeros((ious.shape[0]))
+ if ious.size == 0:
+ tmp_ious = np.hstack((tmp_ious, gt_ious))
+ continue
+ for j in range(ious.shape[0]):
+ gt_max_overlaps = ious.argmax(axis=1)
+ max_ious = ious[np.arange(0, ious.shape[0]), gt_max_overlaps]
+ gt_idx = max_ious.argmax()
+ gt_ious[j] = max_ious[gt_idx]
+ box_idx = gt_max_overlaps[gt_idx]
+ ious[gt_idx, :] = -1
+ ious[:, box_idx] = -1
+ tmp_ious = np.hstack((tmp_ious, gt_ious))
+ _ious[k, :] = tmp_ious
+
+ _ious = np.fliplr(np.sort(_ious, axis=1))
+ recalls = np.zeros((proposal_nums.size, thrs.size))
+ for i, thr in enumerate(thrs):
+ recalls[:, i] = (_ious >= thr).sum(axis=1) / float(total_gt_num)
+
+ return recalls
+
+
+def set_recall_param(proposal_nums, iou_thrs):
+ """Check proposal_nums and iou_thrs and set correct format."""
+ if isinstance(proposal_nums, Sequence):
+ _proposal_nums = np.array(proposal_nums)
+ elif isinstance(proposal_nums, int):
+ _proposal_nums = np.array([proposal_nums])
+ else:
+ _proposal_nums = proposal_nums
+
+ if iou_thrs is None:
+ _iou_thrs = np.array([0.5])
+ elif isinstance(iou_thrs, Sequence):
+ _iou_thrs = np.array(iou_thrs)
+ elif isinstance(iou_thrs, float):
+ _iou_thrs = np.array([iou_thrs])
+ else:
+ _iou_thrs = iou_thrs
+
+ return _proposal_nums, _iou_thrs
+
+
+def eval_recalls(gts,
+ proposals,
+ proposal_nums=None,
+ iou_thrs=0.5,
+ logger=None):
+ """Calculate recalls.
+
+ Args:
+ gts (list[ndarray]): a list of arrays of shape (n, 4)
+ proposals (list[ndarray]): a list of arrays of shape (k, 4) or (k, 5)
+ proposal_nums (int | Sequence[int]): Top N proposals to be evaluated.
+ iou_thrs (float | Sequence[float]): IoU thresholds. Default: 0.5.
+ logger (logging.Logger | str | None): The way to print the recall
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+
+ Returns:
+ ndarray: recalls of different ious and proposal nums
+ """
+
+ img_num = len(gts)
+ assert img_num == len(proposals)
+
+ proposal_nums, iou_thrs = set_recall_param(proposal_nums, iou_thrs)
+
+ all_ious = []
+ for i in range(img_num):
+ if proposals[i].ndim == 2 and proposals[i].shape[1] == 5:
+ scores = proposals[i][:, 4]
+ sort_idx = np.argsort(scores)[::-1]
+ img_proposal = proposals[i][sort_idx, :]
+ else:
+ img_proposal = proposals[i]
+ prop_num = min(img_proposal.shape[0], proposal_nums[-1])
+ if gts[i] is None or gts[i].shape[0] == 0:
+ ious = np.zeros((0, img_proposal.shape[0]), dtype=np.float32)
+ else:
+ ious = bbox_overlaps(gts[i], img_proposal[:prop_num, :4])
+ all_ious.append(ious)
+ all_ious = np.array(all_ious)
+ recalls = _recalls(all_ious, proposal_nums, iou_thrs)
+
+ print_recall_summary(recalls, proposal_nums, iou_thrs, logger=logger)
+ return recalls
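+
+
+# Usage sketch (hypothetical data): with per-image ``gts`` and ``proposals``
+# lists,
+#     recalls = eval_recalls(gts, proposals, proposal_nums=(100, 300))
+# returns a (2, 1) array of recalls for the top-100 and top-300 proposals at
+# the default IoU threshold of 0.5.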
+
+
+def print_recall_summary(recalls,
+ proposal_nums,
+ iou_thrs,
+ row_idxs=None,
+ col_idxs=None,
+ logger=None):
+ """Print recalls in a table.
+
+ Args:
+ recalls (ndarray): calculated from `bbox_recalls`
+ proposal_nums (ndarray or list): top N proposals
+ iou_thrs (ndarray or list): iou thresholds
+ row_idxs (ndarray): which rows(proposal nums) to print
+ col_idxs (ndarray): which cols(iou thresholds) to print
+ logger (logging.Logger | str | None): The way to print the recall
+ summary. See `mmcv.utils.print_log()` for details. Default: None.
+ """
+ proposal_nums = np.array(proposal_nums, dtype=np.int32)
+ iou_thrs = np.array(iou_thrs)
+ if row_idxs is None:
+ row_idxs = np.arange(proposal_nums.size)
+ if col_idxs is None:
+ col_idxs = np.arange(iou_thrs.size)
+ row_header = [''] + iou_thrs[col_idxs].tolist()
+ table_data = [row_header]
+ for i, num in enumerate(proposal_nums[row_idxs]):
+ row = [f'{val:.3f}' for val in recalls[row_idxs[i], col_idxs].tolist()]
+ row.insert(0, num)
+ table_data.append(row)
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+
+def plot_num_recall(recalls, proposal_nums):
+ """Plot Proposal_num-Recalls curve.
+
+ Args:
+ recalls(ndarray or list): shape (k,)
+ proposal_nums(ndarray or list): same shape as `recalls`
+ """
+ if isinstance(proposal_nums, np.ndarray):
+ _proposal_nums = proposal_nums.tolist()
+ else:
+ _proposal_nums = proposal_nums
+ if isinstance(recalls, np.ndarray):
+ _recalls = recalls.tolist()
+ else:
+ _recalls = recalls
+
+ import matplotlib.pyplot as plt
+ f = plt.figure()
+ plt.plot([0] + _proposal_nums, [0] + _recalls)
+ plt.xlabel('Proposal num')
+ plt.ylabel('Recall')
+ plt.axis([0, proposal_nums.max(), 0, 1])
+ f.show()
+
+
+def plot_iou_recall(recalls, iou_thrs):
+ """Plot IoU-Recalls curve.
+
+ Args:
+ recalls(ndarray or list): shape (k,)
+ iou_thrs(ndarray or list): same shape as `recalls`
+ """
+ if isinstance(iou_thrs, np.ndarray):
+ _iou_thrs = iou_thrs.tolist()
+ else:
+ _iou_thrs = iou_thrs
+ if isinstance(recalls, np.ndarray):
+ _recalls = recalls.tolist()
+ else:
+ _recalls = recalls
+
+ import matplotlib.pyplot as plt
+ f = plt.figure()
+ plt.plot(_iou_thrs + [1.0], _recalls + [0.])
+ plt.xlabel('IoU')
+ plt.ylabel('Recall')
+ plt.axis([iou_thrs.min(), 1, 0, 1])
+ f.show()
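
A minimal usage sketch for the recall utilities above (editor's illustration;
the toy ground-truth and proposal boxes are assumed values, and the import
assumes the usual mmdet.core.evaluation exports):

    import numpy as np
    from mmdet.core.evaluation import eval_recalls

    gts = [np.array([[10., 10., 50., 50.]])]
    proposals = [np.array([[12., 11., 48., 52., 0.9],   # (x1, y1, x2, y2, score)
                           [0., 0., 5., 5., 0.3]])]
    recalls = eval_recalls(gts, proposals, proposal_nums=[1, 2], iou_thrs=[0.5])
    # recalls has shape (len(proposal_nums), len(iou_thrs)) == (2, 1);
    # a summary table is also printed via print_log.
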
diff --git a/mmdet/core/export/__init__.py b/mmdet/core/export/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..76589b1f279a71a59a5515d1b78cea0865f83131
--- /dev/null
+++ b/mmdet/core/export/__init__.py
@@ -0,0 +1,8 @@
+from .pytorch2onnx import (build_model_from_cfg,
+ generate_inputs_and_wrap_model,
+ preprocess_example_input)
+
+__all__ = [
+ 'build_model_from_cfg', 'generate_inputs_and_wrap_model',
+ 'preprocess_example_input'
+]
diff --git a/mmdet/core/export/pytorch2onnx.py b/mmdet/core/export/pytorch2onnx.py
new file mode 100644
index 0000000000000000000000000000000000000000..809a817e67446b3c0c7894dcefb3c4bbc29afb7e
--- /dev/null
+++ b/mmdet/core/export/pytorch2onnx.py
@@ -0,0 +1,154 @@
+from functools import partial
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.runner import load_checkpoint
+
+
+def generate_inputs_and_wrap_model(config_path,
+ checkpoint_path,
+ input_config,
+ cfg_options=None):
+ """Prepare sample input and wrap model for ONNX export.
+
+    The ONNX export API only accepts positional args, and all inputs should be
+    torch.Tensor or corresponding types (such as a tuple of tensors).
+    So we should call this function before exporting. This function will:
+
+    1. Generate the inputs which are used to execute the model.
+    2. Wrap the model's forward function.
+
+    For example, the MMDet models' forward function has a parameter
+    ``return_loss: bool``. We want to set it to False, but the export API
+    supports neither bool arguments nor kwargs, so we have to wrap the forward
+    like: ``model.forward = partial(model.forward, return_loss=False)``.
+
+ Args:
+ config_path (str): the OpenMMLab config for the model we want to
+ export to ONNX
+ checkpoint_path (str): Path to the corresponding checkpoint
+        input_config (dict): the exact data in this dict depends on the
+            framework. For MMSeg, we can just declare the input shape
+            and generate the dummy data accordingly. However, for MMDet,
+            we may need to pass a real image path, or the NMS will return
+            None as there is no valid bbox.
+
+ Returns:
+ tuple: (model, tensor_data) wrapped model which can be called by \
+ model(*tensor_data) and a list of inputs which are used to execute \
+ the model while exporting.
+ """
+
+ model = build_model_from_cfg(
+ config_path, checkpoint_path, cfg_options=cfg_options)
+ one_img, one_meta = preprocess_example_input(input_config)
+ tensor_data = [one_img]
+ model.forward = partial(
+ model.forward, img_metas=[[one_meta]], return_loss=False)
+
+    # PyTorch 1.3 has bugs in some ONNX symbolic functions, so we fix them
+    # by replacing the existing ops
+    opset_version = 11
+    # put the import inside the function so that it does not cause an import
+    # error when this function is not used
+ try:
+ from mmcv.onnx.symbolic import register_extra_symbolics
+ except ModuleNotFoundError:
+ raise NotImplementedError('please update mmcv to version>=v1.0.4')
+ register_extra_symbolics(opset_version)
+
+ return model, tensor_data
+
+
+def build_model_from_cfg(config_path, checkpoint_path, cfg_options=None):
+ """Build a model from config and load the given checkpoint.
+
+ Args:
+ config_path (str): the OpenMMLab config for the model we want to
+ export to ONNX
+ checkpoint_path (str): Path to the corresponding checkpoint
+
+ Returns:
+ torch.nn.Module: the built model
+ """
+ from mmdet.models import build_detector
+
+ cfg = mmcv.Config.fromfile(config_path)
+ if cfg_options is not None:
+ cfg.merge_from_dict(cfg_options)
+ # import modules from string list.
+ if cfg.get('custom_imports', None):
+ from mmcv.utils import import_modules_from_strings
+ import_modules_from_strings(**cfg['custom_imports'])
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+ cfg.model.pretrained = None
+ cfg.data.test.test_mode = True
+
+ # build the model
+ cfg.model.train_cfg = None
+ model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+ load_checkpoint(model, checkpoint_path, map_location='cpu')
+ model.cpu().eval()
+ return model
+
+
+def preprocess_example_input(input_config):
+ """Prepare an example input image for ``generate_inputs_and_wrap_model``.
+
+ Args:
+ input_config (dict): customized config describing the example input.
+
+ Returns:
+ tuple: (one_img, one_meta), tensor of the example input image and \
+ meta information for the example input image.
+
+ Examples:
+ >>> from mmdet.core.export import preprocess_example_input
+ >>> input_config = {
+ >>> 'input_shape': (1,3,224,224),
+ >>> 'input_path': 'demo/demo.jpg',
+ >>> 'normalize_cfg': {
+ >>> 'mean': (123.675, 116.28, 103.53),
+ >>> 'std': (58.395, 57.12, 57.375)
+ >>> }
+ >>> }
+ >>> one_img, one_meta = preprocess_example_input(input_config)
+ >>> print(one_img.shape)
+ torch.Size([1, 3, 224, 224])
+ >>> print(one_meta)
+ {'img_shape': (224, 224, 3),
+ 'ori_shape': (224, 224, 3),
+ 'pad_shape': (224, 224, 3),
+ 'filename': '.png',
+ 'scale_factor': 1.0,
+ 'flip': False}
+ """
+ input_path = input_config['input_path']
+ input_shape = input_config['input_shape']
+ one_img = mmcv.imread(input_path)
+ one_img = mmcv.imresize(one_img, input_shape[2:][::-1])
+ show_img = one_img.copy()
+ if 'normalize_cfg' in input_config.keys():
+ normalize_cfg = input_config['normalize_cfg']
+ mean = np.array(normalize_cfg['mean'], dtype=np.float32)
+ std = np.array(normalize_cfg['std'], dtype=np.float32)
+ to_rgb = normalize_cfg.get('to_rgb', True)
+ one_img = mmcv.imnormalize(one_img, mean, std, to_rgb=to_rgb)
+ one_img = one_img.transpose(2, 0, 1)
+ one_img = torch.from_numpy(one_img).unsqueeze(0).float().requires_grad_(
+ True)
+ (_, C, H, W) = input_shape
+ one_meta = {
+ 'img_shape': (H, W, C),
+ 'ori_shape': (H, W, C),
+ 'pad_shape': (H, W, C),
+ 'filename': '.png',
+ 'scale_factor': 1.0,
+ 'flip': False,
+ 'show_img': show_img,
+ }
+
+ return one_img, one_meta
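
A minimal export sketch using the helpers above (editor's illustration; the
config, checkpoint, and demo-image paths are assumptions about this repo's
layout, and the ONNX settings mirror the usual mmdet export script rather than
anything defined here):

    import torch
    from mmdet.core.export import generate_inputs_and_wrap_model

    input_config = {
        'input_shape': (1, 3, 800, 1344),
        'input_path': 'demo/images/img_1.jpg',
        'normalize_cfg': {
            'mean': (123.675, 116.28, 103.53),
            'std': (58.395, 57.12, 57.375),
        },
    }
    model, tensor_data = generate_inputs_and_wrap_model(
        'configs/walt/walt_vehicle.py', 'data/models/walt_vehicle.pth',
        input_config)
    torch.onnx.export(
        model, tensor_data, 'walt_vehicle.onnx',
        export_params=True, do_constant_folding=True, opset_version=11)
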
diff --git a/mmdet/core/mask/__init__.py b/mmdet/core/mask/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab1e88bc686d5c2fe72b3114cb2b3e372e73a0f8
--- /dev/null
+++ b/mmdet/core/mask/__init__.py
@@ -0,0 +1,8 @@
+from .mask_target import mask_target
+from .structures import BaseInstanceMasks, BitmapMasks, PolygonMasks
+from .utils import encode_mask_results, split_combined_polys
+
+__all__ = [
+ 'split_combined_polys', 'mask_target', 'BaseInstanceMasks', 'BitmapMasks',
+ 'PolygonMasks', 'encode_mask_results'
+]
diff --git a/mmdet/core/mask/mask_target.py b/mmdet/core/mask/mask_target.py
new file mode 100644
index 0000000000000000000000000000000000000000..15d26a88bbf3710bd92813335918407db8c4e053
--- /dev/null
+++ b/mmdet/core/mask/mask_target.py
@@ -0,0 +1,122 @@
+import numpy as np
+import torch
+from torch.nn.modules.utils import _pair
+
+
+def mask_target(pos_proposals_list, pos_assigned_gt_inds_list, gt_masks_list,
+ cfg):
+ """Compute mask target for positive proposals in multiple images.
+
+ Args:
+ pos_proposals_list (list[Tensor]): Positive proposals in multiple
+ images.
+ pos_assigned_gt_inds_list (list[Tensor]): Assigned GT indices for each
+ positive proposals.
+ gt_masks_list (list[:obj:`BaseInstanceMasks`]): Ground truth masks of
+ each image.
+ cfg (dict): Config dict that specifies the mask size.
+
+ Returns:
+ list[Tensor]: Mask target of each image.
+
+ Example:
+ >>> import mmcv
+ >>> import mmdet
+ >>> from mmdet.core.mask import BitmapMasks
+ >>> from mmdet.core.mask.mask_target import *
+ >>> H, W = 17, 18
+ >>> cfg = mmcv.Config({'mask_size': (13, 14)})
+ >>> rng = np.random.RandomState(0)
+ >>> # Positive proposals (tl_x, tl_y, br_x, br_y) for each image
+ >>> pos_proposals_list = [
+ >>> torch.Tensor([
+ >>> [ 7.2425, 5.5929, 13.9414, 14.9541],
+ >>> [ 7.3241, 3.6170, 16.3850, 15.3102],
+ >>> ]),
+ >>> torch.Tensor([
+ >>> [ 4.8448, 6.4010, 7.0314, 9.7681],
+ >>> [ 5.9790, 2.6989, 7.4416, 4.8580],
+ >>> [ 0.0000, 0.0000, 0.1398, 9.8232],
+ >>> ]),
+ >>> ]
+ >>> # Corresponding class index for each proposal for each image
+ >>> pos_assigned_gt_inds_list = [
+ >>> torch.LongTensor([7, 0]),
+ >>> torch.LongTensor([5, 4, 1]),
+ >>> ]
+ >>> # Ground truth mask for each true object for each image
+ >>> gt_masks_list = [
+ >>> BitmapMasks(rng.rand(8, H, W), height=H, width=W),
+ >>> BitmapMasks(rng.rand(6, H, W), height=H, width=W),
+ >>> ]
+ >>> mask_targets = mask_target(
+ >>> pos_proposals_list, pos_assigned_gt_inds_list,
+ >>> gt_masks_list, cfg)
+ >>> assert mask_targets.shape == (5,) + cfg['mask_size']
+ """
+ cfg_list = [cfg for _ in range(len(pos_proposals_list))]
+ mask_targets = map(mask_target_single, pos_proposals_list,
+ pos_assigned_gt_inds_list, gt_masks_list, cfg_list)
+ mask_targets = list(mask_targets)
+ if len(mask_targets) > 0:
+ mask_targets = torch.cat(mask_targets)
+ return mask_targets
+
+
+def mask_target_single(pos_proposals, pos_assigned_gt_inds, gt_masks, cfg):
+ """Compute mask target for each positive proposal in the image.
+
+ Args:
+ pos_proposals (Tensor): Positive proposals.
+ pos_assigned_gt_inds (Tensor): Assigned GT inds of positive proposals.
+ gt_masks (:obj:`BaseInstanceMasks`): GT masks in the format of Bitmap
+ or Polygon.
+ cfg (dict): Config dict that indicate the mask size.
+
+ Returns:
+ Tensor: Mask target of each positive proposals in the image.
+
+ Example:
+ >>> import mmcv
+ >>> import mmdet
+ >>> from mmdet.core.mask import BitmapMasks
+ >>> from mmdet.core.mask.mask_target import * # NOQA
+ >>> H, W = 32, 32
+ >>> cfg = mmcv.Config({'mask_size': (7, 11)})
+ >>> rng = np.random.RandomState(0)
+ >>> # Masks for each ground truth box (relative to the image)
+ >>> gt_masks_data = rng.rand(3, H, W)
+ >>> gt_masks = BitmapMasks(gt_masks_data, height=H, width=W)
+ >>> # Predicted positive boxes in one image
+ >>> pos_proposals = torch.FloatTensor([
+ >>> [ 16.2, 5.5, 19.9, 20.9],
+ >>> [ 17.3, 13.6, 19.3, 19.3],
+ >>> [ 14.8, 16.4, 17.0, 23.7],
+ >>> [ 0.0, 0.0, 16.0, 16.0],
+ >>> [ 4.0, 0.0, 20.0, 16.0],
+ >>> ])
+ >>> # For each predicted proposal, its assignment to a gt mask
+ >>> pos_assigned_gt_inds = torch.LongTensor([0, 1, 2, 1, 1])
+ >>> mask_targets = mask_target_single(
+ >>> pos_proposals, pos_assigned_gt_inds, gt_masks, cfg)
+ >>> assert mask_targets.shape == (5,) + cfg['mask_size']
+ """
+ device = pos_proposals.device
+ mask_size = _pair(cfg.mask_size)
+ num_pos = pos_proposals.size(0)
+ if num_pos > 0:
+ proposals_np = pos_proposals.cpu().numpy()
+ maxh, maxw = gt_masks.height, gt_masks.width
+ proposals_np[:, [0, 2]] = np.clip(proposals_np[:, [0, 2]], 0, maxw)
+ proposals_np[:, [1, 3]] = np.clip(proposals_np[:, [1, 3]], 0, maxh)
+ pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
+
+ mask_targets = gt_masks.crop_and_resize(
+ proposals_np, mask_size, device=device,
+ inds=pos_assigned_gt_inds).to_ndarray()
+
+ mask_targets = torch.from_numpy(mask_targets).float().to(device)
+ else:
+ mask_targets = pos_proposals.new_zeros((0, ) + mask_size)
+
+ return mask_targets
diff --git a/mmdet/core/mask/structures.py b/mmdet/core/mask/structures.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7e7ab8620b9f21710fc8a61bdaaec20d96e5c20
--- /dev/null
+++ b/mmdet/core/mask/structures.py
@@ -0,0 +1,1042 @@
+from abc import ABCMeta, abstractmethod
+
+import cv2
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+import torch
+from mmcv.ops.roi_align import roi_align
+
+
+class BaseInstanceMasks(metaclass=ABCMeta):
+ """Base class for instance masks."""
+
+ @abstractmethod
+ def rescale(self, scale, interpolation='nearest'):
+ """Rescale masks as large as possible while keeping the aspect ratio.
+ For details can refer to `mmcv.imrescale`.
+
+ Args:
+ scale (tuple[int]): The maximum size (h, w) of rescaled mask.
+ interpolation (str): Same as :func:`mmcv.imrescale`.
+
+ Returns:
+ BaseInstanceMasks: The rescaled masks.
+ """
+
+ @abstractmethod
+ def resize(self, out_shape, interpolation='nearest'):
+ """Resize masks to the given out_shape.
+
+ Args:
+ out_shape: Target (h, w) of resized mask.
+ interpolation (str): See :func:`mmcv.imresize`.
+
+ Returns:
+ BaseInstanceMasks: The resized masks.
+ """
+
+ @abstractmethod
+ def flip(self, flip_direction='horizontal'):
+ """Flip masks alone the given direction.
+
+ Args:
+ flip_direction (str): Either 'horizontal' or 'vertical'.
+
+ Returns:
+ BaseInstanceMasks: The flipped masks.
+ """
+
+ @abstractmethod
+ def pad(self, out_shape, pad_val):
+ """Pad masks to the given size of (h, w).
+
+ Args:
+ out_shape (tuple[int]): Target (h, w) of padded mask.
+ pad_val (int): The padded value.
+
+ Returns:
+ BaseInstanceMasks: The padded masks.
+ """
+
+ @abstractmethod
+ def crop(self, bbox):
+ """Crop each mask by the given bbox.
+
+ Args:
+ bbox (ndarray): Bbox in format [x1, y1, x2, y2], shape (4, ).
+
+ Return:
+ BaseInstanceMasks: The cropped masks.
+ """
+
+ @abstractmethod
+ def crop_and_resize(self,
+ bboxes,
+ out_shape,
+ inds,
+ device,
+ interpolation='bilinear'):
+ """Crop and resize masks by the given bboxes.
+
+        This function is mainly used in mask target computation.
+        It first aligns masks to bboxes by assigned_inds, then crops the masks
+        with the assigned bboxes and resizes them to (mask_h, mask_w).
+
+ Args:
+ bboxes (Tensor): Bboxes in format [x1, y1, x2, y2], shape (N, 4)
+ out_shape (tuple[int]): Target (h, w) of resized mask
+ inds (ndarray): Indexes to assign masks to each bbox,
+ shape (N,) and values should be between [0, num_masks - 1].
+ device (str): Device of bboxes
+ interpolation (str): See `mmcv.imresize`
+
+ Return:
+ BaseInstanceMasks: the cropped and resized masks.
+ """
+
+ @abstractmethod
+ def expand(self, expanded_h, expanded_w, top, left):
+ """see :class:`Expand`."""
+
+ @property
+ @abstractmethod
+ def areas(self):
+ """ndarray: areas of each instance."""
+
+ @abstractmethod
+ def to_ndarray(self):
+ """Convert masks to the format of ndarray.
+
+ Return:
+ ndarray: Converted masks in the format of ndarray.
+ """
+
+ @abstractmethod
+ def to_tensor(self, dtype, device):
+ """Convert masks to the format of Tensor.
+
+ Args:
+ dtype (str): Dtype of converted mask.
+ device (torch.device): Device of converted masks.
+
+ Returns:
+ Tensor: Converted masks in the format of Tensor.
+ """
+
+ @abstractmethod
+ def translate(self,
+ out_shape,
+ offset,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Translate the masks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ fill_val (int | float): Border value. Default 0.
+ interpolation (str): Same as :func:`mmcv.imtranslate`.
+
+ Returns:
+ Translated masks.
+ """
+
+ def shear(self,
+ out_shape,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear the masks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The shear direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border. Default 0.
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+
+ Returns:
+ ndarray: Sheared masks.
+ """
+
+ @abstractmethod
+ def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
+ """Rotate the masks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ angle (int | float): Rotation angle in degrees. Positive values
+ mean counter-clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation in source image. If not specified, the center of
+ the image will be used.
+ scale (int | float): Isotropic scale factor.
+ fill_val (int | float): Border value. Default 0 for masks.
+
+ Returns:
+ Rotated masks.
+ """
+
+
+class BitmapMasks(BaseInstanceMasks):
+ """This class represents masks in the form of bitmaps.
+
+ Args:
+ masks (ndarray): ndarray of masks in shape (N, H, W), where N is
+ the number of objects.
+ height (int): height of masks
+ width (int): width of masks
+
+ Example:
+ >>> from mmdet.core.mask.structures import * # NOQA
+ >>> num_masks, H, W = 3, 32, 32
+ >>> rng = np.random.RandomState(0)
+ >>> masks = (rng.rand(num_masks, H, W) > 0.1).astype(np.int)
+ >>> self = BitmapMasks(masks, height=H, width=W)
+
+ >>> # demo crop_and_resize
+ >>> num_boxes = 5
+ >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
+ >>> out_shape = (14, 14)
+ >>> inds = torch.randint(0, len(self), size=(num_boxes,))
+ >>> device = 'cpu'
+ >>> interpolation = 'bilinear'
+ >>> new = self.crop_and_resize(
+ ... bboxes, out_shape, inds, device, interpolation)
+ >>> assert len(new) == num_boxes
+ >>> assert new.height, new.width == out_shape
+ """
+
+ def __init__(self, masks, height, width):
+ self.height = height
+ self.width = width
+ if len(masks) == 0:
+ self.masks = np.empty((0, self.height, self.width), dtype=np.uint8)
+ else:
+ assert isinstance(masks, (list, np.ndarray))
+ if isinstance(masks, list):
+ assert isinstance(masks[0], np.ndarray)
+ assert masks[0].ndim == 2 # (H, W)
+ else:
+                assert masks.ndim == 3 or masks.ndim == 4  # (N, H, W)
+
+ self.masks = np.stack(masks).reshape(-1, height, width)
+ assert self.masks.shape[1] == self.height
+ assert self.masks.shape[2] == self.width
+
+ def __getitem__(self, index):
+ """Index the BitmapMask.
+
+ Args:
+ index (int | ndarray): Indices in the format of integer or ndarray.
+
+ Returns:
+ :obj:`BitmapMasks`: Indexed bitmap masks.
+ """
+        masks = self.masks[index].reshape(-1, self.height, self.width)
+
+ return BitmapMasks(masks, self.height, self.width)
+
+ def __iter__(self):
+ return iter(self.masks)
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += f'num_masks={len(self.masks)}, '
+ s += f'height={self.height}, '
+ s += f'width={self.width})'
+ return s
+
+ def __len__(self):
+ """Number of masks."""
+ return len(self.masks)
+
+ def rescale(self, scale, interpolation='nearest'):
+ """See :func:`BaseInstanceMasks.rescale`."""
+ if len(self.masks) == 0:
+ new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
+ rescaled_masks = np.empty((0, new_h, new_w), dtype=np.uint8)
+ else:
+ rescaled_masks = np.stack([
+ mmcv.imrescale(mask, scale, interpolation=interpolation)
+ for mask in self.masks
+ ])
+ height, width = rescaled_masks.shape[1:]
+ return BitmapMasks(rescaled_masks, height, width)
+
+ def resize(self, out_shape, interpolation='nearest'):
+ """See :func:`BaseInstanceMasks.resize`."""
+ if len(self.masks) == 0:
+ resized_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ resized_masks = np.stack([
+ mmcv.imresize(
+ mask, out_shape[::-1], interpolation=interpolation)
+ for mask in self.masks
+ ])
+ return BitmapMasks(resized_masks, *out_shape)
+
+ def flip(self, flip_direction='horizontal'):
+ """See :func:`BaseInstanceMasks.flip`."""
+ assert flip_direction in ('horizontal', 'vertical', 'diagonal')
+
+ if len(self.masks) == 0:
+ flipped_masks = self.masks
+ else:
+ flipped_masks = np.stack([
+ mmcv.imflip(mask, direction=flip_direction)
+ for mask in self.masks
+ ])
+ return BitmapMasks(flipped_masks, self.height, self.width)
+
+ def pad(self, out_shape, pad_val=0):
+ """See :func:`BaseInstanceMasks.pad`."""
+ if len(self.masks) == 0:
+ padded_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ padded_masks = np.stack([
+ mmcv.impad(mask, shape=out_shape, pad_val=pad_val)
+ for mask in self.masks
+ ])
+ return BitmapMasks(padded_masks, *out_shape)
+
+ def crop(self, bbox):
+ """See :func:`BaseInstanceMasks.crop`."""
+ assert isinstance(bbox, np.ndarray)
+ assert bbox.ndim == 1
+
+ # clip the boundary
+ bbox = bbox.copy()
+ bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
+ bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
+ x1, y1, x2, y2 = bbox
+ w = np.maximum(x2 - x1, 1)
+ h = np.maximum(y2 - y1, 1)
+
+ if len(self.masks) == 0:
+ cropped_masks = np.empty((0, h, w), dtype=np.uint8)
+ else:
+ cropped_masks = self.masks[:, y1:y1 + h, x1:x1 + w]
+ return BitmapMasks(cropped_masks, h, w)
+
+ def crop_and_resize(self,
+ bboxes,
+ out_shape,
+ inds,
+ device='cpu',
+ interpolation='bilinear'):
+ """See :func:`BaseInstanceMasks.crop_and_resize`."""
+ if len(self.masks) == 0:
+ empty_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ return BitmapMasks(empty_masks, *out_shape)
+
+ # convert bboxes to tensor
+ if isinstance(bboxes, np.ndarray):
+ bboxes = torch.from_numpy(bboxes).to(device=device)
+ if isinstance(inds, np.ndarray):
+ inds = torch.from_numpy(inds).to(device=device)
+
+ num_bbox = bboxes.shape[0]
+ fake_inds = torch.arange(
+ num_bbox, device=device).to(dtype=bboxes.dtype)[:, None]
+ rois = torch.cat([fake_inds, bboxes], dim=1) # Nx5
+ rois = rois.to(device=device)
+ if num_bbox > 0:
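+            # Two mask targets are produced per box (see the stack below):
+            # the first roi_align pass keeps pixels with any positive label,
+            # the second only pixels labelled 2. Reading these as the visible
+            # vs. full (amodal) masks is an assumption inferred from this
+            # code, not from upstream mmdet.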
+ #masks_vis = (self.masks == 1)
+ masks_vis = (self.masks > 0)
+ gt_masks_th = torch.from_numpy(masks_vis).to(device).index_select(
+ 0, inds).to(dtype=rois.dtype)
+ targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
+ 1.0, 0, 'avg', True).squeeze(1)
+ targets = targets.cpu().numpy().astype(int)
+ resized_masks_vis = (targets > 0.5)
+
+ #masks_full = (self.masks > 0)
+ masks_full = (self.masks == 2)
+ #masks_occ = (self.masks == 2)
+ gt_masks_th = torch.from_numpy(masks_full).to(device).index_select(
+ 0, inds).to(dtype=rois.dtype)
+ targets = roi_align(gt_masks_th[:, None, :, :], rois, out_shape,
+ 1.0, 0, 'avg', True).squeeze(1)
+ targets = targets.cpu().numpy().astype(int)
+ resized_masks_full = (targets > 0.5)
+            resized_masks = np.stack(
+                [resized_masks_vis, resized_masks_full], axis=1)
+ else:
+ resized_masks = []
+ return BitmapMasks(resized_masks, *out_shape)
+
+ def expand(self, expanded_h, expanded_w, top, left):
+ """See :func:`BaseInstanceMasks.expand`."""
+ if len(self.masks) == 0:
+ expanded_mask = np.empty((0, expanded_h, expanded_w),
+ dtype=np.uint8)
+ else:
+ expanded_mask = np.zeros((len(self), expanded_h, expanded_w),
+ dtype=np.uint8)
+ expanded_mask[:, top:top + self.height,
+ left:left + self.width] = self.masks
+ return BitmapMasks(expanded_mask, expanded_h, expanded_w)
+
+ def translate(self,
+ out_shape,
+ offset,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Translate the BitmapMasks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ fill_val (int | float): Border value. Default 0 for masks.
+ interpolation (str): Same as :func:`mmcv.imtranslate`.
+
+ Returns:
+ BitmapMasks: Translated BitmapMasks.
+
+ Example:
+ >>> from mmdet.core.mask.structures import BitmapMasks
+ >>> self = BitmapMasks.random(dtype=np.uint8)
+ >>> out_shape = (32, 32)
+ >>> offset = 4
+ >>> direction = 'horizontal'
+ >>> fill_val = 0
+ >>> interpolation = 'bilinear'
+ >>> # Note, There seem to be issues when:
+ >>> # * out_shape is different than self's shape
+ >>> # * the mask dtype is not supported by cv2.AffineWarp
+ >>> new = self.translate(out_shape, offset, direction, fill_val,
+ >>> interpolation)
+ >>> assert len(new) == len(self)
+ >>> assert new.height, new.width == out_shape
+ """
+ if len(self.masks) == 0:
+ translated_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ translated_masks = mmcv.imtranslate(
+ self.masks.transpose((1, 2, 0)),
+ offset,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation)
+ if translated_masks.ndim == 2:
+ translated_masks = translated_masks[:, :, None]
+ translated_masks = translated_masks.transpose(
+ (2, 0, 1)).astype(self.masks.dtype)
+ return BitmapMasks(translated_masks, *out_shape)
+
+ def shear(self,
+ out_shape,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """Shear the BitmapMasks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The shear direction, either "horizontal"
+ or "vertical".
+ border_value (int | tuple[int]): Value used in case of a
+ constant border.
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+
+ Returns:
+ BitmapMasks: The sheared masks.
+ """
+ if len(self.masks) == 0:
+ sheared_masks = np.empty((0, *out_shape), dtype=np.uint8)
+ else:
+ sheared_masks = mmcv.imshear(
+ self.masks.transpose((1, 2, 0)),
+ magnitude,
+ direction,
+ border_value=border_value,
+ interpolation=interpolation)
+ if sheared_masks.ndim == 2:
+ sheared_masks = sheared_masks[:, :, None]
+ sheared_masks = sheared_masks.transpose(
+ (2, 0, 1)).astype(self.masks.dtype)
+ return BitmapMasks(sheared_masks, *out_shape)
+
+ def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
+ """Rotate the BitmapMasks.
+
+ Args:
+ out_shape (tuple[int]): Shape for output mask, format (h, w).
+ angle (int | float): Rotation angle in degrees. Positive values
+ mean counter-clockwise rotation.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation in source image. If not specified, the center of
+ the image will be used.
+ scale (int | float): Isotropic scale factor.
+ fill_val (int | float): Border value. Default 0 for masks.
+
+ Returns:
+ BitmapMasks: Rotated BitmapMasks.
+ """
+ if len(self.masks) == 0:
+ rotated_masks = np.empty((0, *out_shape), dtype=self.masks.dtype)
+ else:
+ rotated_masks = mmcv.imrotate(
+ self.masks.transpose((1, 2, 0)),
+ angle,
+ center=center,
+ scale=scale,
+ border_value=fill_val)
+ if rotated_masks.ndim == 2:
+ # case when only one mask, (h, w)
+ rotated_masks = rotated_masks[:, :, None] # (h, w, 1)
+ rotated_masks = rotated_masks.transpose(
+ (2, 0, 1)).astype(self.masks.dtype)
+ return BitmapMasks(rotated_masks, *out_shape)
+
+ @property
+ def areas(self):
+ """See :py:attr:`BaseInstanceMasks.areas`."""
+ return self.masks.sum((1, 2))
+
+ def to_ndarray(self):
+ """See :func:`BaseInstanceMasks.to_ndarray`."""
+ return self.masks
+
+ def to_tensor(self, dtype, device):
+ """See :func:`BaseInstanceMasks.to_tensor`."""
+ return torch.tensor(self.masks, dtype=dtype, device=device)
+
+ @classmethod
+ def random(cls,
+ num_masks=3,
+ height=32,
+ width=32,
+ dtype=np.uint8,
+ rng=None):
+ """Generate random bitmap masks for demo / testing purposes.
+
+ Example:
+ >>> from mmdet.core.mask.structures import BitmapMasks
+ >>> self = BitmapMasks.random()
+ >>> print('self = {}'.format(self))
+ self = BitmapMasks(num_masks=3, height=32, width=32)
+ """
+ from mmdet.utils.util_random import ensure_rng
+ rng = ensure_rng(rng)
+ masks = (rng.rand(num_masks, height, width) > 0.1).astype(dtype)
+ self = cls(masks, height=height, width=width)
+ return self
+
+
+class PolygonMasks(BaseInstanceMasks):
+ """This class represents masks in the form of polygons.
+
+ Polygons is a list of three levels. The first level of the list
+ corresponds to objects, the second level to the polys that compose the
+ object, the third level to the poly coordinates
+
+ Args:
+ masks (list[list[ndarray]]): The first level of the list
+ corresponds to objects, the second level to the polys that
+ compose the object, the third level to the poly coordinates
+ height (int): height of masks
+ width (int): width of masks
+
+ Example:
+ >>> from mmdet.core.mask.structures import * # NOQA
+ >>> masks = [
+ >>> [ np.array([0, 0, 10, 0, 10, 10., 0, 10, 0, 0]) ]
+ >>> ]
+ >>> height, width = 16, 16
+ >>> self = PolygonMasks(masks, height, width)
+
+ >>> # demo translate
+ >>> new = self.translate((16, 16), 4., direction='horizontal')
+ >>> assert np.all(new.masks[0][0][1::2] == masks[0][0][1::2])
+ >>> assert np.all(new.masks[0][0][0::2] == masks[0][0][0::2] + 4)
+
+ >>> # demo crop_and_resize
+ >>> num_boxes = 3
+ >>> bboxes = np.array([[0, 0, 30, 10.0]] * num_boxes)
+ >>> out_shape = (16, 16)
+ >>> inds = torch.randint(0, len(self), size=(num_boxes,))
+ >>> device = 'cpu'
+ >>> interpolation = 'bilinear'
+ >>> new = self.crop_and_resize(
+ ... bboxes, out_shape, inds, device, interpolation)
+ >>> assert len(new) == num_boxes
+ >>> assert new.height, new.width == out_shape
+ """
+
+ def __init__(self, masks, height, width):
+ assert isinstance(masks, list)
+ if len(masks) > 0:
+ assert isinstance(masks[0], list)
+ assert isinstance(masks[0][0], np.ndarray)
+
+ self.height = height
+ self.width = width
+ self.masks = masks
+
+ def __getitem__(self, index):
+ """Index the polygon masks.
+
+ Args:
+ index (ndarray | List): The indices.
+
+ Returns:
+ :obj:`PolygonMasks`: The indexed polygon masks.
+ """
+ if isinstance(index, np.ndarray):
+ index = index.tolist()
+ if isinstance(index, list):
+ masks = [self.masks[i] for i in index]
+ else:
+ try:
+ masks = self.masks[index]
+ except Exception:
+ raise ValueError(
+ f'Unsupported input of type {type(index)} for indexing!')
+ if len(masks) and isinstance(masks[0], np.ndarray):
+ masks = [masks] # ensure a list of three levels
+ return PolygonMasks(masks, self.height, self.width)
+
+ def __iter__(self):
+ return iter(self.masks)
+
+ def __repr__(self):
+ s = self.__class__.__name__ + '('
+ s += f'num_masks={len(self.masks)}, '
+ s += f'height={self.height}, '
+ s += f'width={self.width})'
+ return s
+
+ def __len__(self):
+ """Number of masks."""
+ return len(self.masks)
+
+ def rescale(self, scale, interpolation=None):
+ """see :func:`BaseInstanceMasks.rescale`"""
+ new_w, new_h = mmcv.rescale_size((self.width, self.height), scale)
+ if len(self.masks) == 0:
+ rescaled_masks = PolygonMasks([], new_h, new_w)
+ else:
+ rescaled_masks = self.resize((new_h, new_w))
+ return rescaled_masks
+
+ def resize(self, out_shape, interpolation=None):
+ """see :func:`BaseInstanceMasks.resize`"""
+ if len(self.masks) == 0:
+ resized_masks = PolygonMasks([], *out_shape)
+ else:
+ h_scale = out_shape[0] / self.height
+ w_scale = out_shape[1] / self.width
+ resized_masks = []
+ for poly_per_obj in self.masks:
+ resized_poly = []
+ for p in poly_per_obj:
+ p = p.copy()
+ p[0::2] *= w_scale
+ p[1::2] *= h_scale
+ resized_poly.append(p)
+ resized_masks.append(resized_poly)
+ resized_masks = PolygonMasks(resized_masks, *out_shape)
+ return resized_masks
+
+ def flip(self, flip_direction='horizontal'):
+ """see :func:`BaseInstanceMasks.flip`"""
+ assert flip_direction in ('horizontal', 'vertical', 'diagonal')
+ if len(self.masks) == 0:
+ flipped_masks = PolygonMasks([], self.height, self.width)
+ else:
+ flipped_masks = []
+ for poly_per_obj in self.masks:
+ flipped_poly_per_obj = []
+ for p in poly_per_obj:
+ p = p.copy()
+ if flip_direction == 'horizontal':
+ p[0::2] = self.width - p[0::2]
+ elif flip_direction == 'vertical':
+ p[1::2] = self.height - p[1::2]
+ else:
+ p[0::2] = self.width - p[0::2]
+ p[1::2] = self.height - p[1::2]
+ flipped_poly_per_obj.append(p)
+ flipped_masks.append(flipped_poly_per_obj)
+ flipped_masks = PolygonMasks(flipped_masks, self.height,
+ self.width)
+ return flipped_masks
+
+ def crop(self, bbox):
+ """see :func:`BaseInstanceMasks.crop`"""
+ assert isinstance(bbox, np.ndarray)
+ assert bbox.ndim == 1
+
+ # clip the boundary
+ bbox = bbox.copy()
+ bbox[0::2] = np.clip(bbox[0::2], 0, self.width)
+ bbox[1::2] = np.clip(bbox[1::2], 0, self.height)
+ x1, y1, x2, y2 = bbox
+ w = np.maximum(x2 - x1, 1)
+ h = np.maximum(y2 - y1, 1)
+
+ if len(self.masks) == 0:
+ cropped_masks = PolygonMasks([], h, w)
+ else:
+ cropped_masks = []
+ for poly_per_obj in self.masks:
+ cropped_poly_per_obj = []
+ for p in poly_per_obj:
+ # pycocotools will clip the boundary
+ p = p.copy()
+ p[0::2] -= bbox[0]
+ p[1::2] -= bbox[1]
+ cropped_poly_per_obj.append(p)
+ cropped_masks.append(cropped_poly_per_obj)
+ cropped_masks = PolygonMasks(cropped_masks, h, w)
+ return cropped_masks
+
+ def pad(self, out_shape, pad_val=0):
+ """padding has no effect on polygons`"""
+ return PolygonMasks(self.masks, *out_shape)
+
+ def expand(self, *args, **kwargs):
+ """TODO: Add expand for polygon"""
+ raise NotImplementedError
+
+ def crop_and_resize(self,
+ bboxes,
+ out_shape,
+ inds,
+ device='cpu',
+ interpolation='bilinear'):
+ """see :func:`BaseInstanceMasks.crop_and_resize`"""
+ out_h, out_w = out_shape
+ if len(self.masks) == 0:
+ return PolygonMasks([], out_h, out_w)
+
+ resized_masks = []
+ for i in range(len(bboxes)):
+ mask = self.masks[inds[i]]
+ bbox = bboxes[i, :]
+ x1, y1, x2, y2 = bbox
+ w = np.maximum(x2 - x1, 1)
+ h = np.maximum(y2 - y1, 1)
+ h_scale = out_h / max(h, 0.1) # avoid too large scale
+ w_scale = out_w / max(w, 0.1)
+
+ resized_mask = []
+ for p in mask:
+ p = p.copy()
+ # crop
+ # pycocotools will clip the boundary
+ p[0::2] -= bbox[0]
+ p[1::2] -= bbox[1]
+
+ # resize
+ p[0::2] *= w_scale
+ p[1::2] *= h_scale
+ resized_mask.append(p)
+ resized_masks.append(resized_mask)
+ return PolygonMasks(resized_masks, *out_shape)
+
+ def translate(self,
+ out_shape,
+ offset,
+ direction='horizontal',
+ fill_val=None,
+ interpolation=None):
+ """Translate the PolygonMasks.
+
+ Example:
+ >>> self = PolygonMasks.random(dtype=np.int)
+ >>> out_shape = (self.height, self.width)
+ >>> new = self.translate(out_shape, 4., direction='horizontal')
+ >>> assert np.all(new.masks[0][0][1::2] == self.masks[0][0][1::2])
+ >>> assert np.all(new.masks[0][0][0::2] == self.masks[0][0][0::2] + 4) # noqa: E501
+ """
+        assert fill_val is None or fill_val == 0, 'Here fill_val is not '\
+            f'used, and by default should be None or 0, but got {fill_val}.'
+ if len(self.masks) == 0:
+ translated_masks = PolygonMasks([], *out_shape)
+ else:
+ translated_masks = []
+ for poly_per_obj in self.masks:
+ translated_poly_per_obj = []
+ for p in poly_per_obj:
+ p = p.copy()
+ if direction == 'horizontal':
+ p[0::2] = np.clip(p[0::2] + offset, 0, out_shape[1])
+ elif direction == 'vertical':
+ p[1::2] = np.clip(p[1::2] + offset, 0, out_shape[0])
+ translated_poly_per_obj.append(p)
+ translated_masks.append(translated_poly_per_obj)
+ translated_masks = PolygonMasks(translated_masks, *out_shape)
+ return translated_masks
+
+ def shear(self,
+ out_shape,
+ magnitude,
+ direction='horizontal',
+ border_value=0,
+ interpolation='bilinear'):
+ """See :func:`BaseInstanceMasks.shear`."""
+ if len(self.masks) == 0:
+ sheared_masks = PolygonMasks([], *out_shape)
+ else:
+ sheared_masks = []
+ if direction == 'horizontal':
+ shear_matrix = np.stack([[1, magnitude],
+ [0, 1]]).astype(np.float32)
+ elif direction == 'vertical':
+ shear_matrix = np.stack([[1, 0], [magnitude,
+ 1]]).astype(np.float32)
+ for poly_per_obj in self.masks:
+ sheared_poly = []
+ for p in poly_per_obj:
+ p = np.stack([p[0::2], p[1::2]], axis=0) # [2, n]
+ new_coords = np.matmul(shear_matrix, p) # [2, n]
+ new_coords[0, :] = np.clip(new_coords[0, :], 0,
+ out_shape[1])
+ new_coords[1, :] = np.clip(new_coords[1, :], 0,
+ out_shape[0])
+ sheared_poly.append(
+ new_coords.transpose((1, 0)).reshape(-1))
+ sheared_masks.append(sheared_poly)
+ sheared_masks = PolygonMasks(sheared_masks, *out_shape)
+ return sheared_masks
+
+ def rotate(self, out_shape, angle, center=None, scale=1.0, fill_val=0):
+ """See :func:`BaseInstanceMasks.rotate`."""
+ if len(self.masks) == 0:
+ rotated_masks = PolygonMasks([], *out_shape)
+ else:
+ rotated_masks = []
+ rotate_matrix = cv2.getRotationMatrix2D(center, -angle, scale)
+ for poly_per_obj in self.masks:
+ rotated_poly = []
+ for p in poly_per_obj:
+ p = p.copy()
+ coords = np.stack([p[0::2], p[1::2]], axis=1) # [n, 2]
+ # pad 1 to convert from format [x, y] to homogeneous
+ # coordinates format [x, y, 1]
+ coords = np.concatenate(
+ (coords, np.ones((coords.shape[0], 1), coords.dtype)),
+ axis=1) # [n, 3]
+ rotated_coords = np.matmul(
+ rotate_matrix[None, :, :],
+ coords[:, :, None])[..., 0] # [n, 2, 1] -> [n, 2]
+ rotated_coords[:, 0] = np.clip(rotated_coords[:, 0], 0,
+ out_shape[1])
+ rotated_coords[:, 1] = np.clip(rotated_coords[:, 1], 0,
+ out_shape[0])
+ rotated_poly.append(rotated_coords.reshape(-1))
+ rotated_masks.append(rotated_poly)
+ rotated_masks = PolygonMasks(rotated_masks, *out_shape)
+ return rotated_masks
+
+ def to_bitmap(self):
+ """convert polygon masks to bitmap masks."""
+ bitmap_masks = self.to_ndarray()
+ return BitmapMasks(bitmap_masks, self.height, self.width)
+
+ @property
+ def areas(self):
+ """Compute areas of masks.
+
+        This function is modified from detectron2. It only works with
+        polygons, using the shoelace formula.
+
+ Return:
+ ndarray: areas of each instance
+ """ # noqa: W501
+ area = []
+ for polygons_per_obj in self.masks:
+ area_per_obj = 0
+ for p in polygons_per_obj:
+ area_per_obj += self._polygon_area(p[0::2], p[1::2])
+ area.append(area_per_obj)
+ return np.asarray(area)
+
+ def _polygon_area(self, x, y):
+ """Compute the area of a component of a polygon.
+
+ Using the shoelace formula:
+ https://stackoverflow.com/questions/24467972/calculate-area-of-polygon-given-x-y-coordinates
+
+ Args:
+ x (ndarray): x coordinates of the component
+ y (ndarray): y coordinates of the component
+
+ Return:
+            float: the area of the component
+ """ # noqa: 501
+ return 0.5 * np.abs(
+ np.dot(x, np.roll(y, 1)) - np.dot(y, np.roll(x, 1)))
+
+ def to_ndarray(self):
+ """Convert masks to the format of ndarray."""
+ if len(self.masks) == 0:
+ return np.empty((0, self.height, self.width), dtype=np.uint8)
+ bitmap_masks = []
+ for poly_per_obj in self.masks:
+ bitmap_masks.append(
+ polygon_to_bitmap(poly_per_obj, self.height, self.width))
+ return np.stack(bitmap_masks)
+
+ def to_tensor(self, dtype, device):
+ """See :func:`BaseInstanceMasks.to_tensor`."""
+ if len(self.masks) == 0:
+ return torch.empty((0, self.height, self.width),
+ dtype=dtype,
+ device=device)
+ ndarray_masks = self.to_ndarray()
+ return torch.tensor(ndarray_masks, dtype=dtype, device=device)
+
+ @classmethod
+ def random(cls,
+ num_masks=3,
+ height=32,
+ width=32,
+ n_verts=5,
+ dtype=np.float32,
+ rng=None):
+ """Generate random polygon masks for demo / testing purposes.
+
+ Adapted from [1]_
+
+ References:
+ .. [1] https://gitlab.kitware.com/computer-vision/kwimage/-/blob/928cae35ca8/kwimage/structs/polygon.py#L379 # noqa: E501
+
+ Example:
+ >>> from mmdet.core.mask.structures import PolygonMasks
+ >>> self = PolygonMasks.random()
+ >>> print('self = {}'.format(self))
+ """
+ from mmdet.utils.util_random import ensure_rng
+ rng = ensure_rng(rng)
+
+ def _gen_polygon(n, irregularity, spikeyness):
+ """Creates the polygon by sampling points on a circle around the
+ centre. Random noise is added by varying the angular spacing
+ between sequential points, and by varying the radial distance of
+ each point from the centre.
+
+ Based on original code by Mike Ounsworth
+
+ Args:
+ n (int): number of vertices
+ irregularity (float): [0,1] indicating how much variance there
+ is in the angular spacing of vertices. [0,1] will map to
+ [0, 2pi/numberOfVerts]
+ spikeyness (float): [0,1] indicating how much variance there is
+ in each vertex from the circle of radius aveRadius. [0,1]
+ will map to [0, aveRadius]
+
+ Returns:
+ a list of vertices, in CCW order.
+ """
+ from scipy.stats import truncnorm
+ # Generate around the unit circle
+ cx, cy = (0.0, 0.0)
+ radius = 1
+
+ tau = np.pi * 2
+
+ irregularity = np.clip(irregularity, 0, 1) * 2 * np.pi / n
+ spikeyness = np.clip(spikeyness, 1e-9, 1)
+
+ # generate n angle steps
+ lower = (tau / n) - irregularity
+ upper = (tau / n) + irregularity
+ angle_steps = rng.uniform(lower, upper, n)
+
+ # normalize the steps so that point 0 and point n+1 are the same
+ k = angle_steps.sum() / (2 * np.pi)
+ angles = (angle_steps / k).cumsum() + rng.uniform(0, tau)
+
+ # Convert high and low values to be wrt the standard normal range
+ # https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.truncnorm.html
+ low = 0
+ high = 2 * radius
+ mean = radius
+ std = spikeyness
+ a = (low - mean) / std
+ b = (high - mean) / std
+ tnorm = truncnorm(a=a, b=b, loc=mean, scale=std)
+
+ # now generate the points
+ radii = tnorm.rvs(n, random_state=rng)
+ x_pts = cx + radii * np.cos(angles)
+ y_pts = cy + radii * np.sin(angles)
+
+ points = np.hstack([x_pts[:, None], y_pts[:, None]])
+
+ # Scale to 0-1 space
+ points = points - points.min(axis=0)
+ points = points / points.max(axis=0)
+
+ # Randomly place within 0-1 space
+ points = points * (rng.rand() * .8 + .2)
+ min_pt = points.min(axis=0)
+ max_pt = points.max(axis=0)
+
+ high = (1 - max_pt)
+ low = (0 - min_pt)
+ offset = (rng.rand(2) * (high - low)) + low
+ points = points + offset
+ return points
+
+ def _order_vertices(verts):
+ """
+ References:
+ https://stackoverflow.com/questions/1709283/how-can-i-sort-a-coordinate-list-for-a-rectangle-counterclockwise
+ """
+ mlat = verts.T[0].sum() / len(verts)
+ mlng = verts.T[1].sum() / len(verts)
+
+ tau = np.pi * 2
+ angle = (np.arctan2(mlat - verts.T[0], verts.T[1] - mlng) +
+ tau) % tau
+ sortx = angle.argsort()
+ verts = verts.take(sortx, axis=0)
+ return verts
+
+ # Generate a random exterior for each requested mask
+ masks = []
+ for _ in range(num_masks):
+ exterior = _order_vertices(_gen_polygon(n_verts, 0.9, 0.9))
+ exterior = (exterior * [(width, height)]).astype(dtype)
+ masks.append([exterior.ravel()])
+
+ self = cls(masks, height, width)
+ return self
+
+
+def polygon_to_bitmap(polygons, height, width):
+ """Convert masks from the form of polygons to bitmaps.
+
+ Args:
+ polygons (list[ndarray]): masks in polygon representation
+ height (int): mask height
+ width (int): mask width
+
+ Return:
+ ndarray: the converted masks in bitmap representation
+ """
+ rles = maskUtils.frPyObjects(polygons, height, width)
+ rle = maskUtils.merge(rles)
+    bitmap_mask = maskUtils.decode(rle).astype(bool)
+ return bitmap_mask
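
A small sketch of the WALT-specific ``BitmapMasks.crop_and_resize`` behaviour
above (editor's illustration; the toy bitmap and the meaning of labels 1 and 2,
e.g. visible vs. full amodal regions, are assumptions about WALT's annotation
format, not documented semantics):

    import numpy as np
    from mmdet.core.mask import BitmapMasks

    H, W = 32, 32
    gt = np.zeros((1, H, W), dtype=np.uint8)
    gt[0, 8:24, 8:24] = 1      # pixels belonging to the instance
    gt[0, 8:24, 16:24] = 2     # a sub-region carrying label 2
    masks = BitmapMasks(gt, height=H, width=W)

    bboxes = np.array([[8., 8., 24., 24.]], dtype=np.float32)
    inds = np.array([0])
    targets = masks.crop_and_resize(bboxes, (14, 14), inds, device='cpu')
    # targets.masks has shape (2, 14, 14): one channel from pixels labelled
    # > 0 and one from pixels labelled 2, stacked per box.
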
diff --git a/mmdet/core/mask/utils.py b/mmdet/core/mask/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88208291ab2a605bee9fe6c1a28a443b74c6372
--- /dev/null
+++ b/mmdet/core/mask/utils.py
@@ -0,0 +1,63 @@
+import mmcv
+import numpy as np
+import pycocotools.mask as mask_util
+
+
+def split_combined_polys(polys, poly_lens, polys_per_mask):
+ """Split the combined 1-D polys into masks.
+
+ A mask is represented as a list of polys, and a poly is represented as
+ a 1-D array. In dataset, all masks are concatenated into a single 1-D
+ tensor. Here we need to split the tensor into original representations.
+
+ Args:
+ polys (list): a list (length = image num) of 1-D tensors
+ poly_lens (list): a list (length = image num) of poly length
+ polys_per_mask (list): a list (length = image num) of poly number
+ of each mask
+
+ Returns:
+ list: a list (length = image num) of list (length = mask num) of \
+ list (length = poly num) of numpy array.
+ """
+ mask_polys_list = []
+ for img_id in range(len(polys)):
+ polys_single = polys[img_id]
+ polys_lens_single = poly_lens[img_id].tolist()
+ polys_per_mask_single = polys_per_mask[img_id].tolist()
+
+ split_polys = mmcv.slice_list(polys_single, polys_lens_single)
+ mask_polys = mmcv.slice_list(split_polys, polys_per_mask_single)
+ mask_polys_list.append(mask_polys)
+ return mask_polys_list
+
+
+# TODO: move this function to more proper place
+def encode_mask_results(mask_results):
+ """Encode bitmap mask to RLE code.
+
+ Args:
+ mask_results (list | tuple[list]): bitmap mask results.
+ In mask scoring rcnn, mask_results is a tuple of (segm_results,
+ segm_cls_score).
+
+ Returns:
+ list | tuple: RLE encoded mask.
+ """
+ if isinstance(mask_results, tuple): # mask scoring
+ cls_segms, cls_mask_scores = mask_results
+ else:
+ cls_segms = mask_results
+ num_classes = len(cls_segms)
+ encoded_mask_results = [[] for _ in range(num_classes)]
+ for i in range(len(cls_segms)):
+ for cls_segm in cls_segms[i]:
+ encoded_mask_results[i].append(
+ mask_util.encode(
+ np.array(
+ cls_segm[:, :, np.newaxis], order='F',
+ dtype='uint8'))[0]) # encoded with RLE
+ if isinstance(mask_results, tuple):
+ return encoded_mask_results, cls_mask_scores
+ else:
+ return encoded_mask_results
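
A quick sketch of ``encode_mask_results`` (editor's illustration; the toy
per-class mask list below is an assumed stand-in for real detector output):

    import numpy as np
    from mmdet.core.mask import encode_mask_results

    # two classes, one detected instance of class 0 on a 32x32 image
    cls_segms = [[np.zeros((32, 32), dtype=bool)], []]
    cls_segms[0][0][4:12, 4:20] = True
    encoded = encode_mask_results(cls_segms)
    # encoded[0][0] is a COCO RLE dict with 'size' and 'counts' keys
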
diff --git a/mmdet/core/post_processing/__init__.py b/mmdet/core/post_processing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..880b3f06609b050aae163b2e38088c1ee4aa0998
--- /dev/null
+++ b/mmdet/core/post_processing/__init__.py
@@ -0,0 +1,8 @@
+from .bbox_nms import fast_nms, multiclass_nms
+from .merge_augs import (merge_aug_bboxes, merge_aug_masks,
+ merge_aug_proposals, merge_aug_scores)
+
+__all__ = [
+ 'multiclass_nms', 'merge_aug_proposals', 'merge_aug_bboxes',
+ 'merge_aug_scores', 'merge_aug_masks', 'fast_nms'
+]
diff --git a/mmdet/core/post_processing/bbox_nms.py b/mmdet/core/post_processing/bbox_nms.py
new file mode 100644
index 0000000000000000000000000000000000000000..966d3a6ac86637a6be90edc3aab9b6863fb87764
--- /dev/null
+++ b/mmdet/core/post_processing/bbox_nms.py
@@ -0,0 +1,168 @@
+import torch
+from mmcv.ops.nms import batched_nms
+
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+
+
+def multiclass_nms(multi_bboxes,
+ multi_scores,
+ score_thr,
+ nms_cfg,
+ max_num=-1,
+ score_factors=None,
+ return_inds=False):
+ """NMS for multi-class bboxes.
+
+ Args:
+ multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+        multi_scores (Tensor): shape (n, #class+1), where the last column
+ contains scores of the background class, but this will be ignored.
+ score_thr (float): bbox threshold, bboxes with scores lower than it
+ will not be considered.
+        nms_cfg (dict): NMS config, e.g. dict(type='nms', iou_threshold=0.5)
+ max_num (int, optional): if there are more than max_num bboxes after
+ NMS, only top max_num will be kept. Default to -1.
+ score_factors (Tensor, optional): The factors multiplied to scores
+ before applying NMS. Default to None.
+ return_inds (bool, optional): Whether return the indices of kept
+ bboxes. Default to False.
+
+ Returns:
+ tuple: (bboxes, labels, indices (optional)), tensors of shape (k, 5),
+ (k), and (k). Labels are 0-based.
+ """
+ num_classes = multi_scores.size(1) - 1
+ # exclude background category
+ if multi_bboxes.shape[1] > 4:
+ bboxes = multi_bboxes.view(multi_scores.size(0), -1, 4)
+ else:
+ bboxes = multi_bboxes[:, None].expand(
+ multi_scores.size(0), num_classes, 4)
+
+ scores = multi_scores[:, :-1]
+
+ labels = torch.arange(num_classes, dtype=torch.long)
+ labels = labels.view(1, -1).expand_as(scores)
+
+ bboxes = bboxes.reshape(-1, 4)
+ scores = scores.reshape(-1)
+ labels = labels.reshape(-1)
+
+ if not torch.onnx.is_in_onnx_export():
+ # NonZero not supported in TensorRT
+ # remove low scoring boxes
+ valid_mask = scores > score_thr
+ # multiply score_factor after threshold to preserve more bboxes, improve
+ # mAP by 1% for YOLOv3
+ if score_factors is not None:
+ # expand the shape to match original shape of score
+ score_factors = score_factors.view(-1, 1).expand(
+ multi_scores.size(0), num_classes)
+ score_factors = score_factors.reshape(-1)
+ scores = scores * score_factors
+
+ if not torch.onnx.is_in_onnx_export():
+ # NonZero not supported in TensorRT
+ inds = valid_mask.nonzero(as_tuple=False).squeeze(1)
+ bboxes, scores, labels = bboxes[inds], scores[inds], labels[inds]
+ else:
+ # TensorRT NMS plugin has invalid output filled with -1
+ # add dummy data to make detection output correct.
+ bboxes = torch.cat([bboxes, bboxes.new_zeros(1, 4)], dim=0)
+ scores = torch.cat([scores, scores.new_zeros(1)], dim=0)
+ labels = torch.cat([labels, labels.new_zeros(1)], dim=0)
+
+ if bboxes.numel() == 0:
+ if torch.onnx.is_in_onnx_export():
+ raise RuntimeError('[ONNX Error] Can not record NMS '
+ 'as it has not been executed this time')
+ if return_inds:
+ return bboxes, labels, inds
+ else:
+ return bboxes, labels
+
+ dets, keep = batched_nms(bboxes, scores, labels, nms_cfg)
+
+ if max_num > 0:
+ dets = dets[:max_num]
+ keep = keep[:max_num]
+
+ if return_inds:
+ return dets, labels[keep], keep
+ else:
+ return dets, labels[keep]
+
+
+def fast_nms(multi_bboxes,
+ multi_scores,
+ multi_coeffs,
+ score_thr,
+ iou_thr,
+ top_k,
+ max_num=-1):
+ """Fast NMS in `YOLACT `_.
+
+ Fast NMS allows already-removed detections to suppress other detections so
+ that every instance can be decided to be kept or discarded in parallel,
+ which is not possible in traditional NMS. This relaxation allows us to
+ implement Fast NMS entirely in standard GPU-accelerated matrix operations.
+
+ Args:
+ multi_bboxes (Tensor): shape (n, #class*4) or (n, 4)
+ multi_scores (Tensor): shape (n, #class+1), where the last column
+ contains scores of the background class, but this will be ignored.
+ multi_coeffs (Tensor): shape (n, #class*coeffs_dim).
+ score_thr (float): bbox threshold, bboxes with scores lower than it
+ will not be considered.
+ iou_thr (float): IoU threshold to be considered as conflicted.
+ top_k (int): if there are more than top_k bboxes before NMS,
+ only top top_k will be kept.
+ max_num (int): if there are more than max_num bboxes after NMS,
+ only top max_num will be kept. If -1, keep all the bboxes.
+ Default: -1.
+
+ Returns:
+ tuple: (bboxes, labels, coefficients), tensors of shape (k, 5), (k, 1),
+ and (k, coeffs_dim). Labels are 0-based.
+ """
+
+ scores = multi_scores[:, :-1].t() # [#class, n]
+ scores, idx = scores.sort(1, descending=True)
+
+ idx = idx[:, :top_k].contiguous()
+ scores = scores[:, :top_k] # [#class, topk]
+ num_classes, num_dets = idx.size()
+ boxes = multi_bboxes[idx.view(-1), :].view(num_classes, num_dets, 4)
+ coeffs = multi_coeffs[idx.view(-1), :].view(num_classes, num_dets, -1)
+
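+    # Boxes are sorted by score within each class, so after triu_ the entry
+    # (c, i, j) with i < j is the IoU between a higher-scoring box i and a
+    # lower-scoring box j. The max over dim=1 then gives, for every box, its
+    # largest IoU with any higher-scoring box of the same class, and the
+    # threshold below decides all boxes in parallel (already-suppressed boxes
+    # may still suppress others, which is what makes this NMS "fast").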
+ iou = bbox_overlaps(boxes, boxes) # [#class, topk, topk]
+ iou.triu_(diagonal=1)
+ iou_max, _ = iou.max(dim=1)
+
+ # Now just filter out the ones higher than the threshold
+ keep = iou_max <= iou_thr
+
+ # Second thresholding introduces 0.2 mAP gain at negligible time cost
+ keep *= scores > score_thr
+
+ # Assign each kept detection to its corresponding class
+ classes = torch.arange(
+ num_classes, device=boxes.device)[:, None].expand_as(keep)
+ classes = classes[keep]
+
+ boxes = boxes[keep]
+ coeffs = coeffs[keep]
+ scores = scores[keep]
+
+ # Only keep the top max_num highest scores across all classes
+ scores, idx = scores.sort(0, descending=True)
+ if max_num > 0:
+ idx = idx[:max_num]
+ scores = scores[:max_num]
+
+ classes = classes[idx]
+ boxes = boxes[idx]
+ coeffs = coeffs[idx]
+
+ cls_dets = torch.cat([boxes, scores[:, None]], dim=1)
+ return cls_dets, classes, coeffs
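
A minimal usage sketch for ``multiclass_nms`` (editor's illustration; the toy
scores and boxes are assumed values):

    import torch
    from mmdet.core.post_processing import multiclass_nms

    # three candidate boxes, two foreground classes plus a background column
    multi_bboxes = torch.tensor([[10., 10., 50., 50.],
                                 [12., 12., 52., 52.],
                                 [100., 100., 140., 140.]])
    multi_scores = torch.tensor([[0.9, 0.1, 0.0],
                                 [0.8, 0.1, 0.1],
                                 [0.1, 0.7, 0.2]])
    dets, labels = multiclass_nms(
        multi_bboxes, multi_scores, score_thr=0.3,
        nms_cfg=dict(type='nms', iou_threshold=0.5), max_num=100)
    # dets is (k, 5) boxes with scores and labels is (k,) 0-based classes;
    # the two overlapping class-0 boxes collapse into a single detection.
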
diff --git a/mmdet/core/post_processing/merge_augs.py b/mmdet/core/post_processing/merge_augs.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbcf79d1ac20ddc32cb1605e06d253803250c855
--- /dev/null
+++ b/mmdet/core/post_processing/merge_augs.py
@@ -0,0 +1,150 @@
+import copy
+import warnings
+
+import numpy as np
+import torch
+from mmcv import ConfigDict
+from mmcv.ops import nms
+
+from ..bbox import bbox_mapping_back
+
+
+def merge_aug_proposals(aug_proposals, img_metas, cfg):
+ """Merge augmented proposals (multiscale, flip, etc.)
+
+ Args:
+ aug_proposals (list[Tensor]): proposals from different testing
+ schemes, shape (n, 5). Note that they are not rescaled to the
+ original image size.
+
+ img_metas (list[dict]): list of image info dict where each dict has:
+ 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ cfg (dict): rpn test config.
+
+ Returns:
+ Tensor: shape (n, 4), proposals corresponding to original image scale.
+ """
+
+ cfg = copy.deepcopy(cfg)
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+ warnings.warn(
+ 'In rpn_proposal or test_cfg, '
+ 'nms_thr has been moved to a dict named nms as '
+ 'iou_threshold, max_num has been renamed as max_per_img, '
+            'the original argument names and the old way of specifying '
+            'the NMS iou_threshold will be deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+ assert cfg.max_num == cfg.max_per_img, f'You set max_num and ' \
+ f'max_per_img at the same time, but get {cfg.max_num} ' \
+                f'and {cfg.max_per_img} respectively. ' \
+ f'Please delete max_num which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
+ f'iou_threshold in nms and ' \
+ f'nms_thr at the same time, but get ' \
+ f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete the nms_thr ' \
+ f'which will be deprecated.'
+
+ recovered_proposals = []
+ for proposals, img_info in zip(aug_proposals, img_metas):
+ img_shape = img_info['img_shape']
+ scale_factor = img_info['scale_factor']
+ flip = img_info['flip']
+ flip_direction = img_info['flip_direction']
+ _proposals = proposals.clone()
+ _proposals[:, :4] = bbox_mapping_back(_proposals[:, :4], img_shape,
+ scale_factor, flip,
+ flip_direction)
+ recovered_proposals.append(_proposals)
+ aug_proposals = torch.cat(recovered_proposals, dim=0)
+ merged_proposals, _ = nms(aug_proposals[:, :4].contiguous(),
+ aug_proposals[:, -1].contiguous(),
+ cfg.nms.iou_threshold)
+ scores = merged_proposals[:, 4]
+ _, order = scores.sort(0, descending=True)
+ num = min(cfg.max_per_img, merged_proposals.shape[0])
+ order = order[:num]
+ merged_proposals = merged_proposals[order, :]
+ return merged_proposals
+
+
+def merge_aug_bboxes(aug_bboxes, aug_scores, img_metas, rcnn_test_cfg):
+ """Merge augmented detection bboxes and scores.
+
+ Args:
+ aug_bboxes (list[Tensor]): shape (n, 4*#class)
+ aug_scores (list[Tensor] or None): shape (n, #class)
+        img_metas (list[list[dict]]): Image meta info of each augmentation.
+ rcnn_test_cfg (dict): rcnn test config.
+
+ Returns:
+ tuple: (bboxes, scores)
+ """
+ recovered_bboxes = []
+ for bboxes, img_info in zip(aug_bboxes, img_metas):
+ img_shape = img_info[0]['img_shape']
+ scale_factor = img_info[0]['scale_factor']
+ flip = img_info[0]['flip']
+ flip_direction = img_info[0]['flip_direction']
+ bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
+ flip_direction)
+ recovered_bboxes.append(bboxes)
+ bboxes = torch.stack(recovered_bboxes).mean(dim=0)
+ if aug_scores is None:
+ return bboxes
+ else:
+ scores = torch.stack(aug_scores).mean(dim=0)
+ return bboxes, scores
+
+
+def merge_aug_scores(aug_scores):
+ """Merge augmented bbox scores."""
+ if isinstance(aug_scores[0], torch.Tensor):
+ return torch.mean(torch.stack(aug_scores), dim=0)
+ else:
+ return np.mean(aug_scores, axis=0)
+
+
+def merge_aug_masks(aug_masks, img_metas, rcnn_test_cfg, weights=None):
+ """Merge augmented mask prediction.
+
+ Args:
+ aug_masks (list[ndarray]): shape (n, #class, h, w)
+        img_metas (list[list[dict]]): Image meta info of each augmentation.
+ rcnn_test_cfg (dict): rcnn test config.
+
+ Returns:
+        ndarray: Merged masks.
+ """
+ recovered_masks = []
+ for mask, img_info in zip(aug_masks, img_metas):
+ flip = img_info[0]['flip']
+ flip_direction = img_info[0]['flip_direction']
+ if flip:
+ if flip_direction == 'horizontal':
+ mask = mask[:, :, :, ::-1]
+ elif flip_direction == 'vertical':
+ mask = mask[:, :, ::-1, :]
+ else:
+ raise ValueError(
+ f"Invalid flipping direction '{flip_direction}'")
+ recovered_masks.append(mask)
+
+ if weights is None:
+ merged_masks = np.mean(recovered_masks, axis=0)
+ else:
+ merged_masks = np.average(
+ np.array(recovered_masks), axis=0, weights=np.array(weights))
+ return merged_masks
diff --git a/mmdet/core/utils/__init__.py b/mmdet/core/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c51dac6d648f41d5c5f46dbf703f19469a7bb6c
--- /dev/null
+++ b/mmdet/core/utils/__init__.py
@@ -0,0 +1,7 @@
+from .dist_utils import DistOptimizerHook, allreduce_grads, reduce_mean
+from .misc import mask2ndarray, multi_apply, unmap
+
+__all__ = [
+ 'allreduce_grads', 'DistOptimizerHook', 'reduce_mean', 'multi_apply',
+ 'unmap', 'mask2ndarray'
+]
diff --git a/mmdet/core/utils/dist_utils.py b/mmdet/core/utils/dist_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5fe77753313783f95bd7111038ef8b58ee4e4bc5
--- /dev/null
+++ b/mmdet/core/utils/dist_utils.py
@@ -0,0 +1,69 @@
+import warnings
+from collections import OrderedDict
+
+import torch.distributed as dist
+from mmcv.runner import OptimizerHook
+from torch._utils import (_flatten_dense_tensors, _take_tensors,
+ _unflatten_dense_tensors)
+
+
+def _allreduce_coalesced(tensors, world_size, bucket_size_mb=-1):
+ if bucket_size_mb > 0:
+ bucket_size_bytes = bucket_size_mb * 1024 * 1024
+ buckets = _take_tensors(tensors, bucket_size_bytes)
+ else:
+ buckets = OrderedDict()
+ for tensor in tensors:
+ tp = tensor.type()
+ if tp not in buckets:
+ buckets[tp] = []
+ buckets[tp].append(tensor)
+ buckets = buckets.values()
+
+ for bucket in buckets:
+ flat_tensors = _flatten_dense_tensors(bucket)
+ dist.all_reduce(flat_tensors)
+ flat_tensors.div_(world_size)
+ for tensor, synced in zip(
+ bucket, _unflatten_dense_tensors(flat_tensors, bucket)):
+ tensor.copy_(synced)
+
+
+def allreduce_grads(params, coalesce=True, bucket_size_mb=-1):
+ """Allreduce gradients.
+
+ Args:
+ params (list[torch.Parameters]): List of parameters of a model
+ coalesce (bool, optional): Whether allreduce parameters as a whole.
+ Defaults to True.
+ bucket_size_mb (int, optional): Size of bucket, the unit is MB.
+ Defaults to -1.
+ """
+ grads = [
+ param.grad.data for param in params
+ if param.requires_grad and param.grad is not None
+ ]
+ world_size = dist.get_world_size()
+ if coalesce:
+ _allreduce_coalesced(grads, world_size, bucket_size_mb)
+ else:
+ for tensor in grads:
+ dist.all_reduce(tensor.div_(world_size))
+
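+# Illustrative usage (assumed training loop, not part of the original file):
+# after loss.backward(), calling allreduce_grads(list(model.parameters()),
+# coalesce=True, bucket_size_mb=32) averages gradients across all ranks before
+# optimizer.step(); `model` and `optimizer` are placeholders here.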
+
+class DistOptimizerHook(OptimizerHook):
+ """Deprecated optimizer hook for distributed training."""
+
+ def __init__(self, *args, **kwargs):
+ warnings.warn('"DistOptimizerHook" is deprecated, please switch to'
+ '"mmcv.runner.OptimizerHook".')
+ super().__init__(*args, **kwargs)
+
+
+def reduce_mean(tensor):
+ """"Obtain the mean of tensor on different GPUs."""
+ if not (dist.is_available() and dist.is_initialized()):
+ return tensor
+ tensor = tensor.clone()
+ dist.all_reduce(tensor.div_(dist.get_world_size()), op=dist.ReduceOp.SUM)
+ return tensor
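+# Illustrative example (not part of the original file): if every rank holds a
+# scalar loss tensor, reduce_mean(loss) returns the same cross-GPU average on
+# each rank, and is a no-op when torch.distributed is not initialized.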
diff --git a/mmdet/core/utils/misc.py b/mmdet/core/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e22c7b9085317b61a25c67d361f7e70df65bed1
--- /dev/null
+++ b/mmdet/core/utils/misc.py
@@ -0,0 +1,61 @@
+from functools import partial
+
+import numpy as np
+import torch
+from six.moves import map, zip
+
+from ..mask.structures import BitmapMasks, PolygonMasks
+
+
+def multi_apply(func, *args, **kwargs):
+ """Apply function to a list of arguments.
+
+ Note:
+ This function applies the ``func`` to multiple inputs and
+ map the multiple outputs of the ``func`` into different
+ list. Each list contains the same type of outputs corresponding
+ to different inputs.
+
+ Args:
+ func (Function): A function that will be applied to a list of
+ arguments
+
+ Returns:
+ tuple(list): A tuple containing multiple list, each list contains \
+ a kind of returned results by the function
+ """
+ pfunc = partial(func, **kwargs) if kwargs else func
+ map_results = map(pfunc, *args)
+ return tuple(map(list, zip(*map_results)))
+
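+# Illustrative example (not part of the original file):
+#   multi_apply(lambda a, b: (a + b, a * b), [1, 2], [3, 4])
+# returns ([4, 6], [3, 8]): the per-call tuples are transposed into per-field lists.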
+
+def unmap(data, count, inds, fill=0):
+ """Unmap a subset of item (data) back to the original set of items (of size
+ count)"""
+ if data.dim() == 1:
+ ret = data.new_full((count, ), fill)
+ ret[inds.type(torch.bool)] = data
+ else:
+ new_size = (count, ) + data.size()[1:]
+ ret = data.new_full(new_size, fill)
+ ret[inds.type(torch.bool), :] = data
+ return ret
+
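+# Illustrative example (not part of the original file):
+#   unmap(torch.tensor([1., 2.]), 4, torch.tensor([1, 0, 0, 1]))
+# returns tensor([1., 0., 0., 2.]): the two kept values are scattered back into a
+# length-4 tensor filled with `fill` elsewhere.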
+
+def mask2ndarray(mask):
+ """Convert Mask to ndarray..
+
+ Args:
+ mask (:obj:`BitmapMasks` or :obj:`PolygonMasks` or
+ torch.Tensor or np.ndarray): The mask to be converted.
+
+ Returns:
+        np.ndarray: The converted mask, as an ndarray of shape (n, h, w).
+ """
+ if isinstance(mask, (BitmapMasks, PolygonMasks)):
+ mask = mask.to_ndarray()
+ elif isinstance(mask, torch.Tensor):
+ mask = mask.detach().cpu().numpy()
+ elif not isinstance(mask, np.ndarray):
+ raise TypeError(f'Unsupported {type(mask)} data type')
+ return mask
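+# Illustrative example (not part of the original file): mask2ndarray(
+# torch.zeros(2, 4, 4)) returns a (2, 4, 4) float ndarray, while BitmapMasks and
+# PolygonMasks inputs are converted through their own to_ndarray() method.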
diff --git a/mmdet/core/visualization/__init__.py b/mmdet/core/visualization/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ff995c0861490941f8cfc19ebbd41a2ee7e2d65
--- /dev/null
+++ b/mmdet/core/visualization/__init__.py
@@ -0,0 +1,4 @@
+from .image import (color_val_matplotlib, imshow_det_bboxes,
+ imshow_gt_det_bboxes)
+
+__all__ = ['imshow_det_bboxes', 'imshow_gt_det_bboxes', 'color_val_matplotlib']
diff --git a/mmdet/core/visualization/image.py b/mmdet/core/visualization/image.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c5c88581d2f261e7ba203452bdb3d967e886b6d
--- /dev/null
+++ b/mmdet/core/visualization/image.py
@@ -0,0 +1,322 @@
+import matplotlib.pyplot as plt
+import mmcv
+import numpy as np
+import pycocotools.mask as mask_util
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+import cv2
+
+from ..utils import mask2ndarray
+
+EPS = 1e-2
+
+
+def color_val_matplotlib(color):
+ """Convert various input in BGR order to normalized RGB matplotlib color
+ tuples,
+
+ Args:
+ color (:obj:`Color`/str/tuple/int/ndarray): Color inputs
+
+ Returns:
+ tuple[float]: A tuple of 3 normalized floats indicating RGB channels.
+ """
+ color = mmcv.color_val(color)
+ color = [color / 255 for color in color[::-1]]
+ return tuple(color)
+
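+# Illustrative example (not part of the original file): color_val_matplotlib('green')
+# turns mmcv's BGR value (0, 255, 0) into the normalized RGB tuple (0.0, 1.0, 0.0)
+# expected by matplotlib.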
+
+def imshow_det_bboxes(img,
+ bboxes,
+ labels,
+ segms=None,
+ class_names=None,
+ score_thr=0,
+ bbox_color='green',
+ text_color='green',
+ mask_color=None,
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=True,
+ wait_time=0,
+ out_file=None):
+ """Draw bboxes and class labels (with scores) on an image.
+
+ Args:
+ img (str or ndarray): The image to be displayed.
+ bboxes (ndarray): Bounding boxes (with scores), shaped (n, 4) or
+ (n, 5).
+ labels (ndarray): Labels of bboxes.
+ segms (ndarray or None): Masks, shaped (n,h,w) or None
+        class_names (list[str]): Names of each class.
+ score_thr (float): Minimum score of bboxes to be shown. Default: 0
+ bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
+ The tuple of color should be in BGR order. Default: 'green'
+ text_color (str or tuple(int) or :obj:`Color`):Color of texts.
+ The tuple of color should be in BGR order. Default: 'green'
+ mask_color (str or tuple(int) or :obj:`Color`, optional):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: None
+ thickness (int): Thickness of lines. Default: 2
+ font_size (int): Font size of texts. Default: 13
+ show (bool): Whether to show the image. Default: True
+ win_name (str): The window name. Default: ''
+ wait_time (float): Value of waitKey param. Default: 0.
+ out_file (str, optional): The filename to write the image.
+ Default: None
+
+ Returns:
+ ndarray: The image with bboxes drawn on it.
+ """
+    assert bboxes.ndim == 2, \
+        f'bboxes ndim should be 2, but its ndim is {bboxes.ndim}.'
+    assert labels.ndim == 1, \
+        f'labels ndim should be 1, but its ndim is {labels.ndim}.'
+    assert bboxes.shape[0] == labels.shape[0], \
+        'bboxes.shape[0] and labels.shape[0] should be equal.'
+    assert bboxes.shape[1] == 4 or bboxes.shape[1] == 5, \
+        f'bboxes.shape[1] should be 4 or 5, but it is {bboxes.shape[1]}.'
+ img = mmcv.imread(img).astype(np.uint8)
+
+ if score_thr > 0:
+ assert bboxes.shape[1] == 5
+ scores = bboxes[:, -1]
+ inds = scores > score_thr
+ bboxes = bboxes[inds, :]
+ labels = labels[inds]
+        if segms is not None:
+            if len(inds) != len(segms):
+                # duplicate each keep flag when the model outputs two masks per
+                # box, so the index mask covers every segment
+                inds = np.repeat(inds, repeats=2)
+            segms = segms[inds, ...]
+
+ mask_colors = []
+ if labels.shape[0] > 0:
+ if mask_color is None:
+ # random color
+ np.random.seed(46)
+ mask_colors = [
+ np.random.randint(0, 256, (1, 3), dtype=np.uint8)
+ #for _ in range(max(labels) + 2)
+ for _ in range(100)
+ ]
+ else:
+ # specify color
+ mask_colors = [
+ np.array(mmcv.color_val(mask_color)[::-1], dtype=np.uint8)
+ ] * (
+ max(labels) + 1)
+
+ bbox_color = color_val_matplotlib(bbox_color)
+ text_color = color_val_matplotlib(text_color)
+
+ img = mmcv.bgr2rgb(img)
+ width, height = img.shape[1], img.shape[0]
+ img = np.ascontiguousarray(img)
+
+ fig = plt.figure(win_name, frameon=False)
+ plt.title(win_name)
+ canvas = fig.canvas
+ dpi = fig.get_dpi()
+ # add a small EPS to avoid precision lost due to matplotlib's truncation
+ # (https://github.com/matplotlib/matplotlib/issues/15363)
+ fig.set_size_inches((width + EPS) / dpi, (height + EPS) / dpi)
+
+ # remove white edges by set subplot margin
+ plt.subplots_adjust(left=0, right=1, bottom=0, top=1)
+ ax = plt.gca()
+ ax.axis('off')
+
+ polygons = []
+ color = []
+    img_bound = img * 0  # reserved for an optional boundary overlay (currently unused)
+ for i, (bbox, label) in enumerate(zip(bboxes, labels)):
+ bbox_int = bbox.astype(np.int32)
+ poly = [[bbox_int[0], bbox_int[1]], [bbox_int[0], bbox_int[3]],
+ [bbox_int[2], bbox_int[3]], [bbox_int[2], bbox_int[1]]]
+ np_poly = np.array(poly).reshape((4, 2))
+ polygons.append(Polygon(np_poly))
+ color.append(bbox_color)
+ label_text = class_names[
+ label] if class_names is not None else f'class {label}'
+ if len(bbox) > 4:
+ label_text += f'|{bbox[-1]:.02f}'
+ '''
+ ax.text(
+ bbox_int[0],
+ bbox_int[1],
+ f'{label_text}',
+ bbox={
+ 'facecolor': 'black',
+ 'alpha': 0.8,
+ 'pad': 0.7,
+ 'edgecolor': 'none'
+ },
+ color=text_color,
+ fontsize=font_size,
+ verticalalignment='top',
+ horizontalalignment='left')
+ '''
+        if segms is not None:
+            # draw the first set of masks for each detection with a contour border
+            for ll in range(1):
+                color_mask = mask_colors[np.random.randint(0, 99)]
+                mask = segms[len(labels) * ll + i].astype(bool)
+                show_border = True
+                img[mask] = img[mask] * 0.5 + color_mask * 0.5
+                if show_border:
+                    contours, _ = cv2.findContours(
+                        mask.copy().astype('uint8'), cv2.RETR_CCOMP,
+                        cv2.CHAIN_APPROX_NONE)
+                    # border thickness scales with the box size, capped at 6 px
+                    border_thick = min(
+                        int(4 * max(bbox_int[2] - bbox_int[0],
+                                    bbox_int[3] - bbox_int[1]) / 300) + 1, 6)
+                    cv2.drawContours(
+                        img, contours, -1,
+                        (int(color_mask[0][0]), int(color_mask[0][1]),
+                         int(color_mask[0][2])), border_thick)
+ #img = cv2.addWeighted(img,1.0,img_bound,1.0,0)
+
+ #img[img_bound>0] = img_bound
+
+ plt.imshow(img)
+
+ p = PatchCollection(
+ polygons, facecolor='none', edgecolors=color, linewidths=thickness)
+    # bbox outlines are intentionally not drawn; uncomment to overlay them
+    # ax.add_collection(p)
+
+ stream, _ = canvas.print_to_buffer()
+ buffer = np.frombuffer(stream, dtype='uint8')
+ img_rgba = buffer.reshape(height, width, 4)
+ rgb, alpha = np.split(img_rgba, [3], axis=2)
+ img = rgb.astype('uint8')
+ img = mmcv.rgb2bgr(img)
+
+ if show:
+ # We do not use cv2 for display because in some cases, opencv will
+ # conflict with Qt, it will output a warning: Current thread
+ # is not the object's thread. You can refer to
+ # https://github.com/opencv/opencv-python/issues/46 for details
+ if wait_time == 0:
+ plt.show()
+ else:
+ plt.show(block=False)
+ plt.pause(wait_time)
+ if out_file is not None:
+ mmcv.imwrite(img, out_file)
+
+ plt.close()
+
+ return img
+
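+# Illustrative usage (placeholder values, not part of the original file):
+#   imshow_det_bboxes(img, bboxes, labels, segms=segms, class_names=('vehicle',),
+#                     score_thr=0.5, show=False, out_file='vis.png')
+# overlays half-transparent instance masks with contour borders and writes the
+# rendered image to disk.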
+
+def imshow_gt_det_bboxes(img,
+ annotation,
+ result,
+ class_names=None,
+ score_thr=0,
+ gt_bbox_color=(255, 102, 61),
+ gt_text_color=(255, 102, 61),
+ gt_mask_color=(255, 102, 61),
+ det_bbox_color=(72, 101, 241),
+ det_text_color=(72, 101, 241),
+ det_mask_color=(72, 101, 241),
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=True,
+ wait_time=0,
+ out_file=None):
+ """General visualization GT and result function.
+
+ Args:
+        img (str or ndarray): The image to be displayed.
+        annotation (dict): Ground truth annotations, which must contain the
+            keys 'gt_bboxes' and 'gt_labels', and may contain 'gt_masks'.
+ result (tuple[list] or list): The detection result, can be either
+ (bbox, segm) or just bbox.
+        class_names (list[str]): Names of each class.
+ score_thr (float): Minimum score of bboxes to be shown. Default: 0
+ gt_bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
+ The tuple of color should be in BGR order. Default: (255, 102, 61)
+ gt_text_color (str or tuple(int) or :obj:`Color`):Color of texts.
+ The tuple of color should be in BGR order. Default: (255, 102, 61)
+ gt_mask_color (str or tuple(int) or :obj:`Color`, optional):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: (255, 102, 61)
+ det_bbox_color (str or tuple(int) or :obj:`Color`):Color of bbox lines.
+ The tuple of color should be in BGR order. Default: (72, 101, 241)
+ det_text_color (str or tuple(int) or :obj:`Color`):Color of texts.
+ The tuple of color should be in BGR order. Default: (72, 101, 241)
+ det_mask_color (str or tuple(int) or :obj:`Color`, optional):
+ Color of masks. The tuple of color should be in BGR order.
+ Default: (72, 101, 241)
+ thickness (int): Thickness of lines. Default: 2
+ font_size (int): Font size of texts. Default: 13
+ win_name (str): The window name. Default: ''
+ show (bool): Whether to show the image. Default: True
+ wait_time (float): Value of waitKey param. Default: 0.
+ out_file (str, optional): The filename to write the image.
+ Default: None
+
+ Returns:
+ ndarray: The image with bboxes or masks drawn on it.
+ """
+ assert 'gt_bboxes' in annotation
+ assert 'gt_labels' in annotation
+ assert isinstance(
+ result,
+        (tuple, list)), f'Expected tuple or list, but got {type(result)}'
+
+ gt_masks = annotation.get('gt_masks', None)
+ if gt_masks is not None:
+ gt_masks = mask2ndarray(gt_masks)
+
+ img = mmcv.imread(img)
+
+ img = imshow_det_bboxes(
+ img,
+ annotation['gt_bboxes'],
+ annotation['gt_labels'],
+ gt_masks,
+ class_names=class_names,
+ bbox_color=gt_bbox_color,
+ text_color=gt_text_color,
+ mask_color=gt_mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=False)
+
+ if isinstance(result, tuple):
+ bbox_result, segm_result = result
+ if isinstance(segm_result, tuple):
+ segm_result = segm_result[0] # ms rcnn
+ else:
+ bbox_result, segm_result = result, None
+
+ bboxes = np.vstack(bbox_result)
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+
+ segms = None
+ if segm_result is not None and len(labels) > 0: # non empty
+ segms = mmcv.concat_list(segm_result)
+ segms = mask_util.decode(segms)
+ segms = segms.transpose(2, 0, 1)
+
+ img = imshow_det_bboxes(
+ img,
+ bboxes,
+ labels,
+ segms=segms,
+ class_names=class_names,
+ score_thr=score_thr,
+ bbox_color=det_bbox_color,
+ text_color=det_text_color,
+ mask_color=det_mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+ return img
diff --git a/mmdet/datasets/__init__.py b/mmdet/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b18b30a258c32283cbfc03ba01781a19fd993c1
--- /dev/null
+++ b/mmdet/datasets/__init__.py
@@ -0,0 +1,24 @@
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from .cityscapes import CityscapesDataset
+from .coco import CocoDataset
+from .custom import CustomDataset
+from .dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
+ RepeatDataset)
+from .deepfashion import DeepFashionDataset
+from .lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
+from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
+from .utils import (NumClassCheckHook, get_loading_pipeline,
+ replace_ImageToTensor)
+from .voc import VOCDataset
+from .wider_face import WIDERFaceDataset
+from .xml_style import XMLDataset
+
+__all__ = [
+ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
+ 'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
+ 'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
+ 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'ClassBalancedDataset', 'WIDERFaceDataset', 'DATASETS', 'PIPELINES',
+ 'build_dataset', 'replace_ImageToTensor', 'get_loading_pipeline',
+ 'NumClassCheckHook'
+]
diff --git a/mmdet/datasets/builder.py b/mmdet/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9466a517dee746a6677b27a19713f2e89ed7194
--- /dev/null
+++ b/mmdet/datasets/builder.py
@@ -0,0 +1,143 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from mmcv.parallel import collate
+from mmcv.runner import get_dist_info
+from mmcv.utils import Registry, build_from_cfg
+from torch.utils.data import DataLoader
+
+from .samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
+
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ hard_limit = rlimit[1]
+ soft_limit = min(4096, hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+ from .dataset_wrappers import ConcatDataset
+ ann_files = cfg['ann_file']
+ img_prefixes = cfg.get('img_prefix', None)
+ seg_prefixes = cfg.get('seg_prefix', None)
+ proposal_files = cfg.get('proposal_file', None)
+ separate_eval = cfg.get('separate_eval', True)
+
+ datasets = []
+ num_dset = len(ann_files)
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ # pop 'separate_eval' since it is not a valid key for common datasets.
+ if 'separate_eval' in data_cfg:
+ data_cfg.pop('separate_eval')
+ data_cfg['ann_file'] = ann_files[i]
+ if isinstance(img_prefixes, (list, tuple)):
+ data_cfg['img_prefix'] = img_prefixes[i]
+ if isinstance(seg_prefixes, (list, tuple)):
+ data_cfg['seg_prefix'] = seg_prefixes[i]
+ if isinstance(proposal_files, (list, tuple)):
+ data_cfg['proposal_file'] = proposal_files[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+
+ return ConcatDataset(datasets, separate_eval)
+
+
+def build_dataset(cfg, default_args=None):
+ from .dataset_wrappers import (ConcatDataset, RepeatDataset,
+ ClassBalancedDataset)
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'ConcatDataset':
+ dataset = ConcatDataset(
+ [build_dataset(c, default_args) for c in cfg['datasets']],
+ cfg.get('separate_eval', True))
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif cfg['type'] == 'ClassBalancedDataset':
+ dataset = ClassBalancedDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
+ elif isinstance(cfg.get('ann_file'), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
+
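+# Illustrative example (placeholder paths, not part of the original file):
+#   build_dataset(dict(type='CocoDataset', ann_file='anns.json', pipeline=[]))
+# resolves plain configs through the DATASETS registry, while wrapper types such
+# as RepeatDataset and ClassBalancedDataset recurse into their inner 'dataset'.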
+
+def build_dataloader(dataset,
+ samples_per_gpu,
+ workers_per_gpu,
+ num_gpus=1,
+ dist=True,
+ shuffle=True,
+ seed=None,
+ **kwargs):
+ """Build PyTorch DataLoader.
+
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+
+ Args:
+ dataset (Dataset): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
+ dist (bool): Distributed training/test or not. Default: True.
+ shuffle (bool): Whether to shuffle the data at every epoch.
+ Default: True.
+ kwargs: any keyword argument to be used to initialize DataLoader
+
+ Returns:
+ DataLoader: A PyTorch dataloader.
+ """
+ rank, world_size = get_dist_info()
+ if dist:
+ # DistributedGroupSampler will definitely shuffle the data to satisfy
+ # that images on each GPU are in the same group
+ if shuffle:
+ sampler = DistributedGroupSampler(
+ dataset, samples_per_gpu, world_size, rank, seed=seed)
+ else:
+ sampler = DistributedSampler(
+ dataset, world_size, rank, shuffle=False, seed=seed)
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+ sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+
+ data_loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=False,
+ worker_init_fn=init_fn,
+ **kwargs)
+
+ return data_loader
+
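+# Illustrative usage (not part of the original file) for a single-GPU,
+# non-distributed setup:
+#   loader = build_dataloader(dataset, samples_per_gpu=2, workers_per_gpu=2,
+#                             num_gpus=1, dist=False, shuffle=True, seed=0)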
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # The seed of each worker equals
+    # num_workers * rank + worker_id + seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
diff --git a/mmdet/datasets/cityscapes.py b/mmdet/datasets/cityscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..71eead87e7f4e511c0cb59e69c3a599832ada0e4
--- /dev/null
+++ b/mmdet/datasets/cityscapes.py
@@ -0,0 +1,334 @@
+# Modified from https://github.com/facebookresearch/detectron2/blob/master/detectron2/data/datasets/cityscapes.py # noqa
+# and https://github.com/mcordts/cityscapesScripts/blob/master/cityscapesscripts/evaluation/evalInstanceLevelSemanticLabeling.py # noqa
+
+import glob
+import os
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+from mmcv.utils import print_log
+
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class CityscapesDataset(CocoDataset):
+
+ CLASSES = ('person', 'rider', 'car', 'truck', 'bus', 'train', 'motorcycle',
+ 'bicycle')
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ # obtain images that contain annotation
+ ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.coco.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = img_info['id']
+ ann_ids = self.coco.getAnnIds(imgIds=[img_id])
+ ann_info = self.coco.loadAnns(ann_ids)
+ all_iscrowd = all([_['iscrowd'] for _ in ann_info])
+ if self.filter_empty_gt and (self.img_ids[i] not in ids_in_cat
+ or all_iscrowd):
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+ Args:
+ img_info (dict): Image info of an image.
+ ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, \
+ bboxes_ignore, labels, masks, seg_map. \
+ "masks" are already decoded into binary masks.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+
+ for i, ann in enumerate(ann_info):
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ gt_masks_ann.append(ann['segmentation'])
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=img_info['segm_file'])
+
+ return ann
+
+ def results2txt(self, results, outfile_prefix):
+ """Dump the detection results to a txt file.
+
+ Args:
+ results (list[list | tuple]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files.
+ If the prefix is "somepath/xxx",
+ the txt files will be named "somepath/xxx.txt".
+
+ Returns:
+ list[str]: Result txt files which contains corresponding \
+ instance segmentation images.
+ """
+ try:
+ import cityscapesscripts.helpers.labels as CSLabels
+ except ImportError:
+            raise ImportError('Please run "pip install cityscapesscripts" to '
+ 'install cityscapesscripts first.')
+ result_files = []
+ os.makedirs(outfile_prefix, exist_ok=True)
+ prog_bar = mmcv.ProgressBar(len(self))
+ for idx in range(len(self)):
+ result = results[idx]
+ filename = self.data_infos[idx]['filename']
+ basename = osp.splitext(osp.basename(filename))[0]
+ pred_txt = osp.join(outfile_prefix, basename + '_pred.txt')
+
+ bbox_result, segm_result = result
+ bboxes = np.vstack(bbox_result)
+ # segm results
+ if isinstance(segm_result, tuple):
+ # Some detectors use different scores for bbox and mask,
+ # like Mask Scoring R-CNN. Score of segm will be used instead
+ # of bbox score.
+ segms = mmcv.concat_list(segm_result[0])
+ mask_score = segm_result[1]
+ else:
+ # use bbox score for mask score
+ segms = mmcv.concat_list(segm_result)
+ mask_score = [bbox[-1] for bbox in bboxes]
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+
+ assert len(bboxes) == len(segms) == len(labels)
+ num_instances = len(bboxes)
+ prog_bar.update()
+ with open(pred_txt, 'w') as fout:
+ for i in range(num_instances):
+ pred_class = labels[i]
+ classes = self.CLASSES[pred_class]
+ class_id = CSLabels.name2label[classes].id
+ score = mask_score[i]
+ mask = maskUtils.decode(segms[i]).astype(np.uint8)
+ png_filename = osp.join(outfile_prefix,
+ basename + f'_{i}_{classes}.png')
+ mmcv.imwrite(mask, png_filename)
+ fout.write(f'{osp.basename(png_filename)} {class_id} '
+ f'{score}\n')
+ result_files.append(pred_txt)
+
+ return result_files
+
+ def format_results(self, results, txtfile_prefix=None):
+ """Format the results to txt (standard format for Cityscapes
+ evaluation).
+
+ Args:
+ results (list): Testing results of the dataset.
+ txtfile_prefix (str | None): The prefix of txt files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+            tuple: (result_files, tmp_dir), result_files is a list containing \
+                the txt filepaths, tmp_dir is the temporary directory created \
+                for saving txt/png files when txtfile_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if txtfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ txtfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2txt(results, txtfile_prefix)
+
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ outfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=np.arange(0.5, 0.96, 0.05)):
+ """Evaluation in Cityscapes/COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ outfile_prefix (str | None): The prefix of output file. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If results are evaluated with COCO protocol, it would be the
+ prefix of output json file. For example, the metric is 'bbox'
+ and 'segm', then json files would be "a/b/prefix.bbox.json" and
+ "a/b/prefix.segm.json".
+ If results are evaluated with cityscapes protocol, it would be
+ the prefix of output txt/png files. The output files would be
+ png images under folder "a/b/prefix/xxx/" and the file name of
+ images would be written into a txt file
+ "a/b/prefix/xxx_pred.txt", where "xxx" is the video name of
+ cityscapes. If not specified, a temp file will be created.
+ Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+            iou_thrs (Sequence[float]): IoU thresholds used for evaluating
+                recalls. If set to a list, the average recall of all IoUs will
+                also be computed. Default: np.arange(0.5, 0.96, 0.05).
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric or cityscapes mAP \
+ and AP@50.
+ """
+ eval_results = dict()
+
+ metrics = metric.copy() if isinstance(metric, list) else [metric]
+
+ if 'cityscapes' in metrics:
+ eval_results.update(
+ self._evaluate_cityscapes(results, outfile_prefix, logger))
+ metrics.remove('cityscapes')
+
+ # left metrics are all coco metric
+ if len(metrics) > 0:
+ # create CocoDataset with CityscapesDataset annotation
+ self_coco = CocoDataset(self.ann_file, self.pipeline.transforms,
+ None, self.data_root, self.img_prefix,
+ self.seg_prefix, self.proposal_file,
+ self.test_mode, self.filter_empty_gt)
+ # TODO: remove this in the future
+ # reload annotations of correct class
+ self_coco.CLASSES = self.CLASSES
+ self_coco.data_infos = self_coco.load_annotations(self.ann_file)
+ eval_results.update(
+ self_coco.evaluate(results, metrics, logger, outfile_prefix,
+ classwise, proposal_nums, iou_thrs))
+
+ return eval_results
+
+ def _evaluate_cityscapes(self, results, txtfile_prefix, logger):
+ """Evaluation in Cityscapes protocol.
+
+ Args:
+ results (list): Testing results of the dataset.
+ txtfile_prefix (str | None): The prefix of output txt file
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+ dict[str: float]: Cityscapes evaluation results, contains 'mAP' \
+ and 'AP@50'.
+ """
+
+ try:
+ import cityscapesscripts.evaluation.evalInstanceLevelSemanticLabeling as CSEval # noqa
+ except ImportError:
+            raise ImportError('Please run "pip install cityscapesscripts" to '
+ 'install cityscapesscripts first.')
+ msg = 'Evaluating in Cityscapes style'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ result_files, tmp_dir = self.format_results(results, txtfile_prefix)
+
+ if tmp_dir is None:
+ result_dir = osp.join(txtfile_prefix, 'results')
+ else:
+ result_dir = osp.join(tmp_dir.name, 'results')
+
+ eval_results = OrderedDict()
+ print_log(f'Evaluating results under {result_dir} ...', logger=logger)
+
+ # set global states in cityscapes evaluation API
+ CSEval.args.cityscapesPath = os.path.join(self.img_prefix, '../..')
+ CSEval.args.predictionPath = os.path.abspath(result_dir)
+ CSEval.args.predictionWalk = None
+ CSEval.args.JSONOutput = False
+ CSEval.args.colorized = False
+ CSEval.args.gtInstancesFile = os.path.join(result_dir,
+ 'gtInstances.json')
+ CSEval.args.groundTruthSearch = os.path.join(
+ self.img_prefix.replace('leftImg8bit', 'gtFine'),
+ '*/*_gtFine_instanceIds.png')
+
+ groundTruthImgList = glob.glob(CSEval.args.groundTruthSearch)
+ assert len(groundTruthImgList), 'Cannot find ground truth images' \
+ f' in {CSEval.args.groundTruthSearch}.'
+ predictionImgList = []
+ for gt in groundTruthImgList:
+ predictionImgList.append(CSEval.getPrediction(gt, CSEval.args))
+ CSEval_results = CSEval.evaluateImgLists(predictionImgList,
+ groundTruthImgList,
+ CSEval.args)['averages']
+
+ eval_results['mAP'] = CSEval_results['allAp']
+ eval_results['AP@50'] = CSEval_results['allAp50%']
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/mmdet/datasets/coco.py b/mmdet/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef698e5323971601c0985fb27322d7501ded8159
--- /dev/null
+++ b/mmdet/datasets/coco.py
@@ -0,0 +1,548 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import pycocotools
+from mmcv.utils import print_log
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class CocoDataset(CustomDataset):
+
+ CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+ 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+ 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+ 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+ if not getattr(pycocotools, '__version__', '0') >= '12.0.2':
+ raise AssertionError(
+ 'Incompatible version of pycocotools is installed. '
+ 'Run pip uninstall pycocotools first. Then run pip '
+ 'install mmpycocotools to install open-mmlab forked '
+ 'pycocotools.')
+
+ self.coco = COCO(ann_file)
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ total_ann_ids = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ ann_ids = self.coco.get_ann_ids(img_ids=[i])
+ total_ann_ids.extend(ann_ids)
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+ return data_infos
+
+ def get_ann_info(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+ def get_cat_ids(self, idx):
+ """Get COCO category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return [ann['category_id'] for ann in ann_info]
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ # obtain images that contain annotation
+ ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.coco.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = self.img_ids[i]
+ if self.filter_empty_gt and img_id not in ids_in_cat:
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+ Args:
+            img_info (dict): Image info of an image.
+            ann_info (list[dict]): Annotation info of an image.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,\
+ labels, masks, seg_map. "masks" are raw annotations and not \
+ decoded into binary masks.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+ for i, ann in enumerate(ann_info):
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ if inter_w * inter_h == 0:
+ continue
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ gt_masks_ann.append(ann.get('segmentation', None))
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ seg_map = img_info['filename'].replace('jpg', 'png')
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def xyxy2xywh(self, bbox):
+ """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+ evaluation.
+
+ Args:
+ bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+ ``xyxy`` order.
+
+ Returns:
+ list[float]: The converted bounding boxes, in ``xywh`` order.
+ """
+
+ _bbox = bbox.tolist()
+ return [
+ _bbox[0],
+ _bbox[1],
+ _bbox[2] - _bbox[0],
+ _bbox[3] - _bbox[1],
+ ]
+
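+    # Illustrative example (not part of the original file):
+    #   self.xyxy2xywh(np.array([10., 20., 30., 60.])) -> [10.0, 20.0, 20.0, 40.0]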
+ def _proposal2json(self, results):
+ """Convert proposal results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ bboxes = results[idx]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = 1
+ json_results.append(data)
+ return json_results
+
+ def _det2json(self, results):
+ """Convert detection results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ result = results[idx]
+ for label in range(len(result)):
+ bboxes = result[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ json_results.append(data)
+ return json_results
+
+ def _segm2json(self, results):
+ """Convert instance segmentation results to COCO json style."""
+ bbox_json_results = []
+ segm_json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ det, seg = results[idx]
+ for label in range(len(det)):
+ # bbox results
+ bboxes = det[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ bbox_json_results.append(data)
+
+ # segm results
+ # some detectors use different scores for bbox and mask
+ if isinstance(seg, tuple):
+ segms = seg[0][label]
+ mask_score = seg[1][label]
+ else:
+ segms = seg[label]
+ mask_score = [bbox[4] for bbox in bboxes]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(mask_score[i])
+ data['category_id'] = self.cat_ids[label]
+ if isinstance(segms[i]['counts'], bytes):
+ segms[i]['counts'] = segms[i]['counts'].decode()
+ data['segmentation'] = segms[i]
+ segm_json_results.append(data)
+ return bbox_json_results, segm_json_results
+
+ def results2json(self, results, outfile_prefix):
+ """Dump the detection results to a COCO style json file.
+
+ There are 3 types of results: proposals, bbox predictions, mask
+ predictions, and they have different data types. This method will
+ automatically recognize the type, and dump them to json files.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+ "somepath/xxx.proposal.json".
+
+ Returns:
+ dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
+ values are corresponding filenames.
+ """
+ result_files = dict()
+ if isinstance(results[0], list):
+ json_results = self._det2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ mmcv.dump(json_results, result_files['bbox'])
+ elif isinstance(results[0], tuple):
+ json_results = self._segm2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
+ mmcv.dump(json_results[0], result_files['bbox'])
+ mmcv.dump(json_results[1], result_files['segm'])
+ elif isinstance(results[0], np.ndarray):
+ json_results = self._proposal2json(results)
+ result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+ mmcv.dump(json_results, result_files['proposal'])
+ else:
+ raise TypeError('invalid type of results')
+ return result_files
+
+ def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+ gt_bboxes = []
+ for i in range(len(self.img_ids)):
+ ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+ ann_info = self.coco.load_anns(ann_ids)
+ if len(ann_info) == 0:
+ gt_bboxes.append(np.zeros((0, 4)))
+ continue
+ bboxes = []
+ for ann in ann_info:
+ if ann.get('ignore', False) or ann['iscrowd']:
+ continue
+ x1, y1, w, h = ann['bbox']
+ bboxes.append([x1, y1, x1 + w, y1 + h])
+ bboxes = np.array(bboxes, dtype=np.float32)
+ if bboxes.shape[0] == 0:
+ bboxes = np.zeros((0, 4))
+ gt_bboxes.append(bboxes)
+
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+ ar = recalls.mean(axis=1)
+ return ar
+
+ def format_results(self, results, jsonfile_prefix=None, **kwargs):
+ """Format the results to json (standard format for COCO evaluation).
+
+ Args:
+ results (list[tuple | numpy.ndarray]): Testing results of the
+ dataset.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a dict containing \
+                the json filepaths, tmp_dir is the temporary directory created \
+ for saving json files when jsonfile_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=None,
+ metric_items=None):
+ """Evaluation in COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float], optional): IoU threshold used for
+ evaluating recalls/mAPs. If set to a list, the average of all
+ IoUs will also be computed. If not specified, [0.50, 0.55,
+ 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+ Default: None.
+ metric_items (list[str] | str, optional): Metric items that will
+ be returned. If not specified, ``['AR@100', 'AR@300',
+ 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
+ used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+ 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+ ``metric=='bbox' or metric=='segm'``.
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric.
+ """
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ if iou_thrs is None:
+ iou_thrs = np.linspace(
+ .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+ if metric_items is not None:
+ if not isinstance(metric_items, list):
+ metric_items = [metric_items]
+
+        result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+
+        eval_results = OrderedDict()
+        cocoGt = self.coco
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ if metric not in result_files:
+ raise KeyError(f'{metric} is not in results')
+ try:
+ cocoDt = cocoGt.loadRes(result_files[metric])
+ except IndexError:
+ print_log(
+ 'The testing results of the whole dataset is empty.',
+ logger=logger,
+ level=logging.ERROR)
+ break
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+ cocoEval.params.catIds = self.cat_ids
+ cocoEval.params.imgIds = self.img_ids
+ cocoEval.params.maxDets = list(proposal_nums)
+ cocoEval.params.iouThrs = iou_thrs
+ # mapping of cocoEval.stats
+ coco_metric_names = {
+ 'mAP': 0,
+ 'mAP_50': 1,
+ 'mAP_75': 2,
+ 'mAP_s': 3,
+ 'mAP_m': 4,
+ 'mAP_l': 5,
+ 'AR@100': 6,
+ 'AR@300': 7,
+ 'AR@1000': 8,
+ 'AR_s@1000': 9,
+ 'AR_m@1000': 10,
+ 'AR_l@1000': 11
+ }
+ if metric_items is not None:
+ for metric_item in metric_items:
+ if metric_item not in coco_metric_names:
+ raise KeyError(
+ f'metric item {metric_item} is not supported')
+
+ if metric == 'proposal':
+ cocoEval.params.useCats = 0
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if metric_items is None:
+ metric_items = [
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+ 'AR_m@1000', 'AR_l@1000'
+ ]
+
+ for item in metric_items:
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
+ eval_results[item] = val
+ else:
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if classwise: # Compute per-category AP
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = cocoEval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self.coco.loadCats(catId)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ if metric_items is None:
+ metric_items = [
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+
+ for metric_item in metric_items:
+ key = f'{metric}_{metric_item}'
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+ )
+ eval_results[key] = val
+ ap = cocoEval.stats[:6]
+ eval_results[f'{metric}_mAP_copypaste'] = (
+ f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+ f'{ap[4]:.3f} {ap[5]:.3f}')
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/mmdet/datasets/custom.py b/mmdet/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..356f01ede6456312920b6fe8fa618258d8898075
--- /dev/null
+++ b/mmdet/datasets/custom.py
@@ -0,0 +1,334 @@
+import os.path as osp
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from mmdet.core import eval_map, eval_recalls
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDataset(Dataset):
+ """Custom dataset for detection.
+
+ The annotation format is shown as follows. The `ann` field is optional for
+ testing.
+
+ .. code-block:: none
+
+ [
+ {
+ 'filename': 'a.jpg',
+ 'width': 1280,
+ 'height': 720,
+ 'ann': {
+ 'bboxes': (n, 4) in (x1, y1, x2, y2) order.
+ 'labels': (n, ),
+ 'bboxes_ignore': (k, 4), (optional field)
+                    'labels_ignore': (k, ) (optional field)
+ }
+ },
+ ...
+ ]
+
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ data_root (str, optional): Data root for ``ann_file``,
+ ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
+ test_mode (bool, optional): If set True, annotation will not be loaded.
+ filter_empty_gt (bool, optional): If set true, images without bounding
+ boxes of the dataset's classes will be filtered out. This option
+ only works when `test_mode=False`, i.e., we never filter images
+ during tests.
+ """
+
+ CLASSES = None
+
+ def __init__(self,
+ ann_file,
+ pipeline,
+ classes=None,
+ data_root=None,
+ img_prefix='',
+ seg_prefix=None,
+ proposal_file=None,
+ test_mode=False,
+ filter_empty_gt=True):
+ self.ann_file = ann_file
+ self.data_root = data_root
+ self.img_prefix = img_prefix
+ self.seg_prefix = seg_prefix
+ self.proposal_file = proposal_file
+ self.test_mode = test_mode
+ self.filter_empty_gt = filter_empty_gt
+ self.CLASSES = self.get_classes(classes)
+
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.ann_file):
+ self.ann_file = osp.join(self.data_root, self.ann_file)
+ if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
+ self.img_prefix = osp.join(self.data_root, self.img_prefix)
+ if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
+ self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
+ if not (self.proposal_file is None
+ or osp.isabs(self.proposal_file)):
+ self.proposal_file = osp.join(self.data_root,
+ self.proposal_file)
+ # load annotations (and proposals)
+ self.data_infos = self.load_annotations(self.ann_file)
+
+ if self.proposal_file is not None:
+ self.proposals = self.load_proposals(self.proposal_file)
+ else:
+ self.proposals = None
+
+ # filter images too small and containing no annotations
+ if not test_mode:
+ valid_inds = self._filter_imgs()
+ self.data_infos = [self.data_infos[i] for i in valid_inds]
+ if self.proposals is not None:
+ self.proposals = [self.proposals[i] for i in valid_inds]
+ # set group flag for the sampler
+ self._set_group_flag()
+
+ # processing pipeline
+ self.pipeline = Compose(pipeline)
+
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.data_infos)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from annotation file."""
+ return mmcv.load(ann_file)
+
+ def load_proposals(self, proposal_file):
+ """Load proposal from proposal file."""
+ return mmcv.load(proposal_file)
+
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.data_infos[idx]['ann']
+
+ def get_cat_ids(self, idx):
+ """Get category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+        return self.data_infos[idx]['ann']['labels'].astype(np.int64).tolist()
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['img_prefix'] = self.img_prefix
+ results['seg_prefix'] = self.seg_prefix
+ results['proposal_file'] = self.proposal_file
+ results['bbox_fields'] = []
+ results['mask_fields'] = []
+ results['seg_fields'] = []
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small."""
+ if self.filter_empty_gt:
+ warnings.warn(
+ 'CustomDataset does not support filtering empty gt images.')
+ valid_inds = []
+ for i, img_info in enumerate(self.data_infos):
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ return valid_inds
+
+ def _set_group_flag(self):
+ """Set flag according to image aspect ratio.
+
+ Images with aspect ratio greater than 1 will be set as group 1,
+ otherwise group 0.
+ """
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+ for i in range(len(self)):
+ img_info = self.data_infos[i]
+ if img_info['width'] / img_info['height'] > 1:
+ self.flag[i] = 1
+
+ def _rand_another(self, idx):
+ """Get another random index from the same group as the given index."""
+ pool = np.where(self.flag == self.flag[idx])[0]
+ return np.random.choice(pool)
+
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set \
+ True).
+ """
+
+        if self.test_mode:
+            # if preparing a test sample fails, fall back to the next index
+            while True:
+                try:
+                    return self.prepare_test_img(idx)
+                except Exception:
+                    idx = idx + 1
+
+        while True:
+            # if preparing a training sample fails, fall back to the previous index
+            try:
+                data = self.prepare_train_img(idx)
+            except Exception:
+                data = self.prepare_train_img(idx - 1)
+
+            if data is None:
+                idx = self._rand_another(idx)
+                continue
+            return data
+
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys \
+ introduced by pipeline.
+ """
+
+ img_info = self.data_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by \
+ pipeline.
+ """
+
+ img_info = self.data_infos[idx]
+ results = dict(img_info=img_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ @classmethod
+ def get_classes(cls, classes=None):
+ """Get class names of current dataset.
+
+ Args:
+ classes (Sequence[str] | str | None): If classes is None, use
+ default CLASSES defined by builtin dataset. If classes is a
+ string, take it as a file name. The file contains the name of
+ classes where each line contains one class name. If classes is
+ a tuple or list, override the CLASSES defined by the dataset.
+
+ Returns:
+ tuple[str] or list[str]: Names of categories of the dataset.
+ """
+ if classes is None:
+ return cls.CLASSES
+
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmcv.list_from_file(classes)
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ else:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+ return class_names
+
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+
+ def evaluate(self,
+ results,
+ metric='mAP',
+ logger=None,
+ proposal_nums=(100, 300, 1000),
+ iou_thr=0.5,
+ scale_ranges=None):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thr (float | list[float]): IoU threshold. Default: 0.5.
+ scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
+ Default: None.
+ """
+
+ if not isinstance(metric, str):
+ assert len(metric) == 1
+ metric = metric[0]
+ allowed_metrics = ['mAP', 'recall']
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ annotations = [self.get_ann_info(i) for i in range(len(self))]
+ eval_results = OrderedDict()
+ iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
+ if metric == 'mAP':
+ assert isinstance(iou_thrs, list)
+ mean_aps = []
+ for iou_thr in iou_thrs:
+ print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
+ mean_ap, _ = eval_map(
+ results,
+ annotations,
+ scale_ranges=scale_ranges,
+ iou_thr=iou_thr,
+ dataset=self.CLASSES,
+ logger=logger)
+ mean_aps.append(mean_ap)
+ eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
+ eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
+ elif metric == 'recall':
+ gt_bboxes = [ann['bboxes'] for ann in annotations]
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
+ for i, num in enumerate(proposal_nums):
+ for j, iou in enumerate(iou_thrs):
+ eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
+ if recalls.shape[1] > 1:
+ ar = recalls.mean(axis=1)
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ return eval_results
diff --git a/mmdet/datasets/dataset_wrappers.py b/mmdet/datasets/dataset_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..55ad5cb60e581a96bdbd1fbbeebc2f46f8c4e899
--- /dev/null
+++ b/mmdet/datasets/dataset_wrappers.py
@@ -0,0 +1,282 @@
+import bisect
+import math
+from collections import defaultdict
+
+import numpy as np
+from mmcv.utils import print_log
+from torch.utils.data.dataset import ConcatDataset as _ConcatDataset
+
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class ConcatDataset(_ConcatDataset):
+ """A wrapper of concatenated dataset.
+
+ Same as :obj:`torch.utils.data.dataset.ConcatDataset`, but
+ concat the group flag for image aspect ratio.
+
+ Args:
+ datasets (list[:obj:`Dataset`]): A list of datasets.
+ separate_eval (bool): Whether to evaluate the results
+ separately if it is used as validation dataset.
+ Defaults to True.
+ """
+
+ def __init__(self, datasets, separate_eval=True):
+ super(ConcatDataset, self).__init__(datasets)
+ self.CLASSES = datasets[0].CLASSES
+ self.separate_eval = separate_eval
+ if not separate_eval:
+ if any([isinstance(ds, CocoDataset) for ds in datasets]):
+ raise NotImplementedError(
+ 'Evaluating concatenated CocoDataset as a whole is not'
+ ' supported! Please set "separate_eval=True"')
+ elif len(set([type(ds) for ds in datasets])) != 1:
+ raise NotImplementedError(
+ 'All the datasets should have same types')
+
+ if hasattr(datasets[0], 'flag'):
+ flags = []
+ for i in range(0, len(datasets)):
+ flags.append(datasets[i].flag)
+ self.flag = np.concatenate(flags)
+
+ def get_cat_ids(self, idx):
+ """Get category ids of concatenated dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ if idx < 0:
+ if -idx > len(self):
+ raise ValueError(
+ 'absolute value of index should not exceed dataset length')
+ idx = len(self) + idx
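+        # `cumulative_sizes` is sorted, so bisect_right locates the
+        # sub-dataset that the global index falls into; subtracting the
+        # previous cumulative size then yields the local sample index.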
+ dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
+ if dataset_idx == 0:
+ sample_idx = idx
+ else:
+ sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
+ return self.datasets[dataset_idx].get_cat_ids(sample_idx)
+
+ def evaluate(self, results, logger=None, **kwargs):
+ """Evaluate the results.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+
+ Returns:
+            dict[str, float]: AP results of the total dataset or each separate
+ dataset if `self.separate_eval=True`.
+ """
+ assert len(results) == self.cumulative_sizes[-1], \
+ ('Dataset and results have different sizes: '
+ f'{self.cumulative_sizes[-1]} v.s. {len(results)}')
+
+ # Check whether all the datasets support evaluation
+ for dataset in self.datasets:
+ assert hasattr(dataset, 'evaluate'), \
+ f'{type(dataset)} does not implement evaluate function'
+
+ if self.separate_eval:
+ dataset_idx = -1
+ total_eval_results = dict()
+ for size, dataset in zip(self.cumulative_sizes, self.datasets):
+ start_idx = 0 if dataset_idx == -1 else \
+ self.cumulative_sizes[dataset_idx]
+ end_idx = self.cumulative_sizes[dataset_idx + 1]
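+                # `cumulative_sizes` holds running dataset lengths, so
+                # results[start_idx:end_idx] are exactly the predictions
+                # belonging to the current sub-dataset.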
+
+ results_per_dataset = results[start_idx:end_idx]
+ print_log(
+                    f'\nEvaluating {dataset.ann_file} with '
+ f'{len(results_per_dataset)} images now',
+ logger=logger)
+
+ eval_results_per_dataset = dataset.evaluate(
+ results_per_dataset, logger=logger, **kwargs)
+ dataset_idx += 1
+ for k, v in eval_results_per_dataset.items():
+ total_eval_results.update({f'{dataset_idx}_{k}': v})
+
+ return total_eval_results
+ elif any([isinstance(ds, CocoDataset) for ds in self.datasets]):
+ raise NotImplementedError(
+ 'Evaluating concatenated CocoDataset as a whole is not'
+ ' supported! Please set "separate_eval=True"')
+ elif len(set([type(ds) for ds in self.datasets])) != 1:
+ raise NotImplementedError(
+ 'All the datasets should have same types')
+ else:
+ original_data_infos = self.datasets[0].data_infos
+ self.datasets[0].data_infos = sum(
+ [dataset.data_infos for dataset in self.datasets], [])
+ eval_results = self.datasets[0].evaluate(
+ results, logger=logger, **kwargs)
+ self.datasets[0].data_infos = original_data_infos
+ return eval_results
+
+
+@DATASETS.register_module()
+class RepeatDataset(object):
+ """A wrapper of repeated dataset.
+
+ The length of repeated dataset will be `times` larger than the original
+ dataset. This is useful when the data loading time is long but the dataset
+ is small. Using RepeatDataset can reduce the data loading time between
+ epochs.
+
+ Args:
+ dataset (:obj:`Dataset`): The dataset to be repeated.
+ times (int): Repeat times.
+ """
+
+ def __init__(self, dataset, times):
+ self.dataset = dataset
+ self.times = times
+ self.CLASSES = dataset.CLASSES
+ if hasattr(self.dataset, 'flag'):
+ self.flag = np.tile(self.dataset.flag, times)
+
+ self._ori_len = len(self.dataset)
+
+ def __getitem__(self, idx):
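+        # Map the virtual index in [0, times * len(dataset)) back onto the
+        # underlying dataset with a modulo.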
+ return self.dataset[idx % self._ori_len]
+
+ def get_cat_ids(self, idx):
+ """Get category ids of repeat dataset by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ return self.dataset.get_cat_ids(idx % self._ori_len)
+
+ def __len__(self):
+ """Length after repetition."""
+ return self.times * self._ori_len
+
+
+# Modified from https://github.com/facebookresearch/detectron2/blob/41d475b75a230221e21d9cac5d69655e3415e3a4/detectron2/data/samplers/distributed_sampler.py#L57 # noqa
+@DATASETS.register_module()
+class ClassBalancedDataset(object):
+ """A wrapper of repeated dataset with repeat factor.
+
+ Suitable for training on class imbalanced datasets like LVIS. Following
+    the sampling strategy in the `paper <https://arxiv.org/abs/1908.03195>`_,
+ in each epoch, an image may appear multiple times based on its
+ "repeat factor".
+    The repeat factor for an image is a function of the frequency of the
+    rarest category labeled in that image. The "frequency of category c"
+    in [0, 1] is defined as the fraction of images in the training set
+    (without repeats) in which category c appears.
+    The dataset needs to implement :func:`self.get_cat_ids` to support
+    ClassBalancedDataset.
+
+    The repeat factor is computed as follows.
+
+    1. For each category c, compute the fraction of images
+       that contain it: :math:`f(c)`
+    2. For each category c, compute the category-level repeat factor:
+       :math:`r(c) = \max(1, \sqrt{t/f(c)})`
+    3. For each image I, compute the image-level repeat factor:
+       :math:`r(I) = \max_{c \in I} r(c)`
+
+ Args:
+ dataset (:obj:`CustomDataset`): The dataset to be repeated.
+ oversample_thr (float): frequency threshold below which data is
+ repeated. For categories with ``f_c >= oversample_thr``, there is
+            no oversampling. For categories with ``f_c < oversample_thr``, the
+            degree of oversampling follows the square-root inverse frequency
+            heuristic above.
+        filter_empty_gt (bool, optional): If set true, images without bounding
+            boxes will not be oversampled. Otherwise, they will be categorized
+            as the pure background class and included in the oversampling.
+ Default: True.
+ """
+
+ def __init__(self, dataset, oversample_thr, filter_empty_gt=True):
+ self.dataset = dataset
+ self.oversample_thr = oversample_thr
+ self.filter_empty_gt = filter_empty_gt
+ self.CLASSES = dataset.CLASSES
+
+ repeat_factors = self._get_repeat_factors(dataset, oversample_thr)
+ repeat_indices = []
+ for dataset_idx, repeat_factor in enumerate(repeat_factors):
+ repeat_indices.extend([dataset_idx] * math.ceil(repeat_factor))
+ self.repeat_indices = repeat_indices
+
+ flags = []
+ if hasattr(self.dataset, 'flag'):
+ for flag, repeat_factor in zip(self.dataset.flag, repeat_factors):
+ flags.extend([flag] * int(math.ceil(repeat_factor)))
+ assert len(flags) == len(repeat_indices)
+ self.flag = np.asarray(flags, dtype=np.uint8)
+
+ def _get_repeat_factors(self, dataset, repeat_thr):
+ """Get repeat factor for each images in the dataset.
+
+ Args:
+ dataset (:obj:`CustomDataset`): The dataset
+ repeat_thr (float): The threshold of frequency. If an image
+ contains the categories whose frequency below the threshold,
+ it would be repeated.
+
+ Returns:
+ list[float]: The repeat factors for each images in the dataset.
+ """
+
+ # 1. For each category c, compute the fraction # of images
+ # that contain it: f(c)
+ category_freq = defaultdict(int)
+ num_images = len(dataset)
+ for idx in range(num_images):
+ cat_ids = set(self.dataset.get_cat_ids(idx))
+ if len(cat_ids) == 0 and not self.filter_empty_gt:
+ cat_ids = set([len(self.CLASSES)])
+ for cat_id in cat_ids:
+ category_freq[cat_id] += 1
+ for k, v in category_freq.items():
+ category_freq[k] = v / num_images
+
+ # 2. For each category c, compute the category-level repeat factor:
+ # r(c) = max(1, sqrt(t/f(c)))
+ category_repeat = {
+ cat_id: max(1.0, math.sqrt(repeat_thr / cat_freq))
+ for cat_id, cat_freq in category_freq.items()
+ }
+
+ # 3. For each image I, compute the image-level repeat factor:
+ # r(I) = max_{c in I} r(c)
+ repeat_factors = []
+ for idx in range(num_images):
+ cat_ids = set(self.dataset.get_cat_ids(idx))
+ if len(cat_ids) == 0 and not self.filter_empty_gt:
+ cat_ids = set([len(self.CLASSES)])
+ repeat_factor = 1
+ if len(cat_ids) > 0:
+ repeat_factor = max(
+ {category_repeat[cat_id]
+ for cat_id in cat_ids})
+ repeat_factors.append(repeat_factor)
+
+ return repeat_factors
+
+ def __getitem__(self, idx):
+ ori_index = self.repeat_indices[idx]
+ return self.dataset[ori_index]
+
+ def __len__(self):
+ """Length after repetition."""
+ return len(self.repeat_indices)
diff --git a/mmdet/datasets/deepfashion.py b/mmdet/datasets/deepfashion.py
new file mode 100644
index 0000000000000000000000000000000000000000..1125376091f2d4ee6843ae4f2156b3b0453be369
--- /dev/null
+++ b/mmdet/datasets/deepfashion.py
@@ -0,0 +1,10 @@
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class DeepFashionDataset(CocoDataset):
+
+ CLASSES = ('top', 'skirt', 'leggings', 'dress', 'outer', 'pants', 'bag',
+ 'neckwear', 'headwear', 'eyeglass', 'belt', 'footwear', 'hair',
+ 'skin', 'face')
diff --git a/mmdet/datasets/lvis.py b/mmdet/datasets/lvis.py
new file mode 100644
index 0000000000000000000000000000000000000000..122c64e79cf5f060d7ceddf4ad29c4debe40944b
--- /dev/null
+++ b/mmdet/datasets/lvis.py
@@ -0,0 +1,742 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import numpy as np
+from mmcv.utils import print_log
+from terminaltables import AsciiTable
+
+from .builder import DATASETS
+from .coco import CocoDataset
+
+
+@DATASETS.register_module()
+class LVISV05Dataset(CocoDataset):
+
+ CLASSES = (
+ 'acorn', 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock',
+ 'alcohol', 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet',
+ 'antenna', 'apple', 'apple_juice', 'applesauce', 'apricot', 'apron',
+ 'aquarium', 'armband', 'armchair', 'armoire', 'armor', 'artichoke',
+ 'trash_can', 'ashtray', 'asparagus', 'atomizer', 'avocado', 'award',
+ 'awning', 'ax', 'baby_buggy', 'basketball_backboard', 'backpack',
+ 'handbag', 'suitcase', 'bagel', 'bagpipe', 'baguet', 'bait', 'ball',
+ 'ballet_skirt', 'balloon', 'bamboo', 'banana', 'Band_Aid', 'bandage',
+ 'bandanna', 'banjo', 'banner', 'barbell', 'barge', 'barrel',
+ 'barrette', 'barrow', 'baseball_base', 'baseball', 'baseball_bat',
+ 'baseball_cap', 'baseball_glove', 'basket', 'basketball_hoop',
+ 'basketball', 'bass_horn', 'bat_(animal)', 'bath_mat', 'bath_towel',
+ 'bathrobe', 'bathtub', 'batter_(food)', 'battery', 'beachball', 'bead',
+ 'beaker', 'bean_curd', 'beanbag', 'beanie', 'bear', 'bed',
+ 'bedspread', 'cow', 'beef_(food)', 'beeper', 'beer_bottle', 'beer_can',
+ 'beetle', 'bell', 'bell_pepper', 'belt', 'belt_buckle', 'bench',
+ 'beret', 'bib', 'Bible', 'bicycle', 'visor', 'binder', 'binoculars',
+ 'bird', 'birdfeeder', 'birdbath', 'birdcage', 'birdhouse',
+ 'birthday_cake', 'birthday_card', 'biscuit_(bread)', 'pirate_flag',
+ 'black_sheep', 'blackboard', 'blanket', 'blazer', 'blender', 'blimp',
+ 'blinker', 'blueberry', 'boar', 'gameboard', 'boat', 'bobbin',
+ 'bobby_pin', 'boiled_egg', 'bolo_tie', 'deadbolt', 'bolt', 'bonnet',
+ 'book', 'book_bag', 'bookcase', 'booklet', 'bookmark',
+ 'boom_microphone', 'boot', 'bottle', 'bottle_opener', 'bouquet',
+ 'bow_(weapon)', 'bow_(decorative_ribbons)', 'bow-tie', 'bowl',
+ 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'bowling_pin',
+ 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
+ 'bread-bin', 'breechcloth', 'bridal_gown', 'briefcase',
+ 'bristle_brush', 'broccoli', 'broach', 'broom', 'brownie',
+ 'brussels_sprouts', 'bubble_gum', 'bucket', 'horse_buggy', 'bull',
+ 'bulldog', 'bulldozer', 'bullet_train', 'bulletin_board',
+ 'bulletproof_vest', 'bullhorn', 'corned_beef', 'bun', 'bunk_bed',
+ 'buoy', 'burrito', 'bus_(vehicle)', 'business_card', 'butcher_knife',
+ 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
+ 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
+ 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
+ 'can', 'can_opener', 'candelabrum', 'candle', 'candle_holder',
+ 'candy_bar', 'candy_cane', 'walking_cane', 'canister', 'cannon',
+ 'canoe', 'cantaloup', 'canteen', 'cap_(headwear)', 'bottle_cap',
+ 'cape', 'cappuccino', 'car_(automobile)', 'railcar_(part_of_a_train)',
+ 'elevator_car', 'car_battery', 'identity_card', 'card', 'cardigan',
+ 'cargo_ship', 'carnation', 'horse_carriage', 'carrot', 'tote_bag',
+ 'cart', 'carton', 'cash_register', 'casserole', 'cassette', 'cast',
+ 'cat', 'cauliflower', 'caviar', 'cayenne_(spice)', 'CD_player',
+ 'celery', 'cellular_telephone', 'chain_mail', 'chair', 'chaise_longue',
+ 'champagne', 'chandelier', 'chap', 'checkbook', 'checkerboard',
+ 'cherry', 'chessboard', 'chest_of_drawers_(furniture)',
+ 'chicken_(animal)', 'chicken_wire', 'chickpea', 'Chihuahua',
+ 'chili_(vegetable)', 'chime', 'chinaware', 'crisp_(potato_chip)',
+ 'poker_chip', 'chocolate_bar', 'chocolate_cake', 'chocolate_milk',
+ 'chocolate_mousse', 'choker', 'chopping_board', 'chopstick',
+ 'Christmas_tree', 'slide', 'cider', 'cigar_box', 'cigarette',
+ 'cigarette_case', 'cistern', 'clarinet', 'clasp', 'cleansing_agent',
+ 'clementine', 'clip', 'clipboard', 'clock', 'clock_tower',
+ 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster', 'coat',
+ 'coat_hanger', 'coatrack', 'cock', 'coconut', 'coffee_filter',
+ 'coffee_maker', 'coffee_table', 'coffeepot', 'coil', 'coin',
+ 'colander', 'coleslaw', 'coloring_material', 'combination_lock',
+ 'pacifier', 'comic_book', 'computer_keyboard', 'concrete_mixer',
+ 'cone', 'control', 'convertible_(automobile)', 'sofa_bed', 'cookie',
+ 'cookie_jar', 'cooking_utensil', 'cooler_(for_food)',
+ 'cork_(bottle_plug)', 'corkboard', 'corkscrew', 'edible_corn',
+ 'cornbread', 'cornet', 'cornice', 'cornmeal', 'corset',
+ 'romaine_lettuce', 'costume', 'cougar', 'coverall', 'cowbell',
+ 'cowboy_hat', 'crab_(animal)', 'cracker', 'crape', 'crate', 'crayon',
+ 'cream_pitcher', 'credit_card', 'crescent_roll', 'crib', 'crock_pot',
+ 'crossbar', 'crouton', 'crow', 'crown', 'crucifix', 'cruise_ship',
+ 'police_cruiser', 'crumb', 'crutch', 'cub_(animal)', 'cube',
+ 'cucumber', 'cufflink', 'cup', 'trophy_cup', 'cupcake', 'hair_curler',
+ 'curling_iron', 'curtain', 'cushion', 'custard', 'cutting_tool',
+ 'cylinder', 'cymbal', 'dachshund', 'dagger', 'dartboard',
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
+ 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
+ 'dishwasher_detergent', 'diskette', 'dispenser', 'Dixie_cup', 'dog',
+ 'dog_collar', 'doll', 'dollar', 'dolphin', 'domestic_ass', 'eye_mask',
+ 'doorbell', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
+ 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
+ 'dresser', 'drill', 'drinking_fountain', 'drone', 'dropper',
+ 'drum_(musical_instrument)', 'drumstick', 'duck', 'duckling',
+ 'duct_tape', 'duffel_bag', 'dumbbell', 'dumpster', 'dustpan',
+ 'Dutch_oven', 'eagle', 'earphone', 'earplug', 'earring', 'easel',
+ 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
+ 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
+ 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
+ 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
+ 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
+ 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
+ 'fireplug', 'fish', 'fish_(food)', 'fishbowl', 'fishing_boat',
+ 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flash',
+ 'flashlight', 'fleece', 'flip-flop_(sandal)', 'flipper_(footwear)',
+ 'flower_arrangement', 'flute_glass', 'foal', 'folding_chair',
+ 'food_processor', 'football_(American)', 'football_helmet',
+ 'footstool', 'fork', 'forklift', 'freight_car', 'French_toast',
+ 'freshener', 'frisbee', 'frog', 'fruit_juice', 'fruit_salad',
+ 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
+ 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
+ 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'giant_panda',
+ 'gift_wrap', 'ginger', 'giraffe', 'cincture',
+ 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
+ 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
+ 'gorilla', 'gourd', 'surgical_gown', 'grape', 'grasshopper', 'grater',
+ 'gravestone', 'gravy_boat', 'green_bean', 'green_onion', 'griddle',
+ 'grillroom', 'grinder_(tool)', 'grits', 'grizzly', 'grocery_bag',
+ 'guacamole', 'guitar', 'gull', 'gun', 'hair_spray', 'hairbrush',
+ 'hairnet', 'hairpin', 'ham', 'hamburger', 'hammer', 'hammock',
+ 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
+ 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
+ 'hardback_book', 'harmonium', 'hat', 'hatbox', 'hatch', 'veil',
+ 'headband', 'headboard', 'headlight', 'headscarf', 'headset',
+ 'headstall_(for_horses)', 'hearing_aid', 'heart', 'heater',
+ 'helicopter', 'helmet', 'heron', 'highchair', 'hinge', 'hippopotamus',
+ 'hockey_stick', 'hog', 'home_plate_(baseball)', 'honey', 'fume_hood',
+ 'hook', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
+ 'ice_tea', 'igniter', 'incense', 'inhaler', 'iPod',
+ 'iron_(for_clothing)', 'ironing_board', 'jacket', 'jam', 'jean',
+ 'jeep', 'jelly_bean', 'jersey', 'jet_plane', 'jewelry', 'joystick',
+ 'jumpsuit', 'kayak', 'keg', 'kennel', 'kettle', 'key', 'keycard',
+ 'kilt', 'kimono', 'kitchen_sink', 'kitchen_table', 'kite', 'kitten',
+ 'kiwi_fruit', 'knee_pad', 'knife', 'knight_(chess_piece)',
+ 'knitting_needle', 'knob', 'knocker_(on_a_door)', 'koala', 'lab_coat',
+ 'ladder', 'ladle', 'ladybug', 'lamb_(animal)', 'lamb-chop', 'lamp',
+ 'lamppost', 'lampshade', 'lantern', 'lanyard', 'laptop_computer',
+ 'lasagna', 'latch', 'lawn_mower', 'leather', 'legging_(clothing)',
+ 'Lego', 'lemon', 'lemonade', 'lettuce', 'license_plate', 'life_buoy',
+ 'life_jacket', 'lightbulb', 'lightning_rod', 'lime', 'limousine',
+ 'linen_paper', 'lion', 'lip_balm', 'lipstick', 'liquor', 'lizard',
+ 'Loafer_(type_of_shoe)', 'log', 'lollipop', 'lotion',
+ 'speaker_(stero_equipment)', 'loveseat', 'machine_gun', 'magazine',
+ 'magnet', 'mail_slot', 'mailbox_(at_home)', 'mallet', 'mammoth',
+ 'mandarin_orange', 'manger', 'manhole', 'map', 'marker', 'martini',
+ 'mascot', 'mashed_potato', 'masher', 'mask', 'mast',
+ 'mat_(gym_equipment)', 'matchbox', 'mattress', 'measuring_cup',
+ 'measuring_stick', 'meatball', 'medicine', 'melon', 'microphone',
+ 'microscope', 'microwave_oven', 'milestone', 'milk', 'minivan',
+ 'mint_candy', 'mirror', 'mitten', 'mixer_(kitchen_tool)', 'money',
+ 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
+ 'motor_scooter', 'motor_vehicle', 'motorboat', 'motorcycle',
+ 'mound_(baseball)', 'mouse_(animal_rodent)',
+ 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
+ 'music_stool', 'musical_instrument', 'nailfile', 'nameplate', 'napkin',
+ 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newsstand',
+ 'nightshirt', 'nosebag_(for_animals)', 'noseband_(for_animals)',
+ 'notebook', 'notepad', 'nut', 'nutcracker', 'oar', 'octopus_(food)',
+ 'octopus_(animal)', 'oil_lamp', 'olive_oil', 'omelet', 'onion',
+ 'orange_(fruit)', 'orange_juice', 'oregano', 'ostrich', 'ottoman',
+ 'overalls_(clothing)', 'owl', 'packet', 'inkpad', 'pad', 'paddle',
+ 'padlock', 'paintbox', 'paintbrush', 'painting', 'pajamas', 'palette',
+ 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake', 'pantyhose',
+ 'papaya', 'paperclip', 'paper_plate', 'paper_towel', 'paperback_book',
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)',
+ 'parchment', 'parka', 'parking_meter', 'parrot',
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'pegboard',
+ 'pelican', 'pen', 'pencil', 'pencil_box', 'pencil_sharpener',
+ 'pendulum', 'penguin', 'pennant', 'penny_(coin)', 'pepper',
+ 'pepper_mill', 'perfume', 'persimmon', 'baby', 'pet', 'petfood',
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
+ 'plate', 'platter', 'playing_card', 'playpen', 'pliers',
+ 'plow_(farm_equipment)', 'pocket_watch', 'pocketknife',
+ 'poker_(fire_stirring_tool)', 'pole', 'police_van', 'polo_shirt',
+ 'poncho', 'pony', 'pool_table', 'pop_(soda)', 'portrait',
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
+ 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'printer',
+ 'projectile_(weapon)', 'projector', 'propeller', 'prune', 'pudding',
+ 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher', 'puppet',
+ 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit', 'race_car',
+ 'racket', 'radar', 'radiator', 'radio_receiver', 'radish', 'raft',
+ 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
+ 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
+ 'recliner', 'record_player', 'red_cabbage', 'reflector',
+ 'remote_control', 'rhinoceros', 'rib_(food)', 'rifle', 'ring',
+ 'river_boat', 'road_map', 'robe', 'rocking_chair', 'roller_skate',
+ 'Rollerblade', 'rolling_pin', 'root_beer',
+ 'router_(computer_equipment)', 'rubber_band', 'runner_(carpet)',
+ 'plastic_bag', 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag',
+ 'safety_pin', 'sail', 'salad', 'salad_plate', 'salami',
+ 'salmon_(fish)', 'salmon_(food)', 'salsa', 'saltshaker',
+ 'sandal_(type_of_shoe)', 'sandwich', 'satchel', 'saucepan', 'saucer',
+ 'sausage', 'sawhorse', 'saxophone', 'scale_(measuring_instrument)',
+ 'scarecrow', 'scarf', 'school_bus', 'scissors', 'scoreboard',
+ 'scrambled_eggs', 'scraper', 'scratcher', 'screwdriver',
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
+ 'seashell', 'seedling', 'serving_dish', 'sewing_machine', 'shaker',
+ 'shampoo', 'shark', 'sharpener', 'Sharpie', 'shaver_(electric)',
+ 'shaving_cream', 'shawl', 'shears', 'sheep', 'shepherd_dog',
+ 'sherbert', 'shield', 'shirt', 'shoe', 'shopping_bag', 'shopping_cart',
+ 'short_pants', 'shot_glass', 'shoulder_bag', 'shovel', 'shower_head',
+ 'shower_curtain', 'shredder_(for_paper)', 'sieve', 'signboard', 'silo',
+ 'sink', 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka',
+ 'ski_pole', 'skirt', 'sled', 'sleeping_bag', 'sling_(bandage)',
+ 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
+ 'snowmobile', 'soap', 'soccer_ball', 'sock', 'soda_fountain',
+ 'carbonated_water', 'sofa', 'softball', 'solar_array', 'sombrero',
+ 'soup', 'soup_bowl', 'soupspoon', 'sour_cream', 'soya_milk',
+ 'space_shuttle', 'sparkler_(fireworks)', 'spatula', 'spear',
+ 'spectacles', 'spice_rack', 'spider', 'sponge', 'spoon', 'sportswear',
+ 'spotlight', 'squirrel', 'stapler_(stapling_machine)', 'starfish',
+ 'statue_(sculpture)', 'steak_(food)', 'steak_knife',
+ 'steamer_(kitchen_appliance)', 'steering_wheel', 'stencil',
+ 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
+ 'stirrup', 'stockings_(leg_wear)', 'stool', 'stop_sign', 'brake_light',
+ 'stove', 'strainer', 'strap', 'straw_(for_drinking)', 'strawberry',
+ 'street_sign', 'streetlight', 'string_cheese', 'stylus', 'subwoofer',
+ 'sugar_bowl', 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower',
+ 'sunglasses', 'sunhat', 'sunscreen', 'surfboard', 'sushi', 'mop',
+ 'sweat_pants', 'sweatband', 'sweater', 'sweatshirt', 'sweet_potato',
+ 'swimsuit', 'sword', 'syringe', 'Tabasco_sauce', 'table-tennis_table',
+ 'table', 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag',
+ 'taillight', 'tambourine', 'army_tank', 'tank_(storage_vessel)',
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
+ 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
+ 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
+ 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
+ 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
+ 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
+ 'tray', 'tree_house', 'trench_coat', 'triangle_(musical_instrument)',
+ 'tricycle', 'tripod', 'trousers', 'truck', 'truffle_(chocolate)',
+ 'trunk', 'vat', 'turban', 'turkey_(bird)', 'turkey_(food)', 'turnip',
+ 'turtle', 'turtleneck_(clothing)', 'typewriter', 'umbrella',
+ 'underwear', 'unicycle', 'urinal', 'urn', 'vacuum_cleaner', 'valve',
+ 'vase', 'vending_machine', 'vent', 'videotape', 'vinegar', 'violin',
+ 'vodka', 'volleyball', 'vulture', 'waffle', 'waffle_iron', 'wagon',
+ 'wagon_wheel', 'walking_stick', 'wall_clock', 'wall_socket', 'wallet',
+ 'walrus', 'wardrobe', 'wasabi', 'automatic_washer', 'watch',
+ 'water_bottle', 'water_cooler', 'water_faucet', 'water_filter',
+ 'water_heater', 'water_jug', 'water_gun', 'water_scooter', 'water_ski',
+ 'water_tower', 'watering_can', 'watermelon', 'weathervane', 'webcam',
+ 'wedding_cake', 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair',
+ 'whipped_cream', 'whiskey', 'whistle', 'wick', 'wig', 'wind_chime',
+ 'windmill', 'window_box_(for_plants)', 'windshield_wiper', 'windsock',
+ 'wine_bottle', 'wine_bucket', 'wineglass', 'wing_chair',
+ 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon', 'wreath',
+ 'wrench', 'wristband', 'wristlet', 'yacht', 'yak', 'yogurt',
+ 'yoke_(animal_equipment)', 'zebra', 'zucchini')
+
+ def load_annotations(self, ann_file):
+ """Load annotation from lvis style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from LVIS api.
+ """
+
+ try:
+ import lvis
+ assert lvis.__version__ >= '10.5.3'
+ from lvis import LVIS
+ except AssertionError:
+ raise AssertionError('Incompatible version of lvis is installed. '
+ 'Run pip uninstall lvis first. Then run pip '
+ 'install mmlvis to install open-mmlab forked '
+ 'lvis. ')
+ except ImportError:
+ raise ImportError('Package lvis is not installed. Please run pip '
+ 'install mmlvis to install open-mmlab forked '
+ 'lvis.')
+ self.coco = LVIS(ann_file)
+ self.cat_ids = self.coco.get_cat_ids()
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ if info['file_name'].startswith('COCO'):
+                # Convert from the COCO 2014 file naming convention of
+ # COCO_[train/val/test]2014_000000000000.jpg to the 2017
+ # naming convention of 000000000000.jpg
+ # (LVIS v1 will fix this naming issue)
+ info['filename'] = info['file_name'][-16:]
+ else:
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ return data_infos
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=np.arange(0.5, 0.96, 0.05)):
+ """Evaluation in LVIS protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+            jsonfile_prefix (str | None): The prefix of output json files,
+                including the file path and the filename prefix, e.g.,
+                "a/b/prefix". If not specified, a temp file will be created.
+                Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float]): IoU threshold used for evaluating
+ recalls. If set to a list, the average recall of all IoUs will
+                also be computed. Default: np.arange(0.5, 0.96, 0.05).
+
+ Returns:
+ dict[str, float]: LVIS style metrics.
+ """
+
+ try:
+ import lvis
+ assert lvis.__version__ >= '10.5.3'
+ from lvis import LVISResults, LVISEval
+ except AssertionError:
+ raise AssertionError('Incompatible version of lvis is installed. '
+ 'Run pip uninstall lvis first. Then run pip '
+ 'install mmlvis to install open-mmlab forked '
+ 'lvis. ')
+ except ImportError:
+ raise ImportError('Package lvis is not installed. Please run pip '
+ 'install mmlvis to install open-mmlab forked '
+ 'lvis.')
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError('metric {} is not supported'.format(metric))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+
+ eval_results = OrderedDict()
+ # get original api
+ lvis_gt = self.coco
+ for metric in metrics:
+ msg = 'Evaluating {}...'.format(metric)
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results['AR@{}'.format(num)] = ar[i]
+ log_msg.append('\nAR@{}\t{:.4f}'.format(num, ar[i]))
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ if metric not in result_files:
+ raise KeyError('{} is not in results'.format(metric))
+ try:
+ lvis_dt = LVISResults(lvis_gt, result_files[metric])
+ except IndexError:
+ print_log(
+ 'The testing results of the whole dataset is empty.',
+ logger=logger,
+ level=logging.ERROR)
+ break
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ lvis_eval = LVISEval(lvis_gt, lvis_dt, iou_type)
+ lvis_eval.params.imgIds = self.img_ids
+ if metric == 'proposal':
+ lvis_eval.params.useCats = 0
+ lvis_eval.params.maxDets = list(proposal_nums)
+ lvis_eval.evaluate()
+ lvis_eval.accumulate()
+ lvis_eval.summarize()
+ for k, v in lvis_eval.get_results().items():
+ if k.startswith('AR'):
+ val = float('{:.3f}'.format(float(v)))
+ eval_results[k] = val
+ else:
+ lvis_eval.evaluate()
+ lvis_eval.accumulate()
+ lvis_eval.summarize()
+ lvis_results = lvis_eval.get_results()
+                if classwise:
+                    # Compute per-category AP,
+                    # from https://github.com/facebookresearch/detectron2/
+ precisions = lvis_eval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self.coco.load_cats(catId)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ for k, v in lvis_results.items():
+ if k.startswith('AP'):
+ key = '{}_{}'.format(metric, k)
+ val = float('{:.3f}'.format(float(v)))
+ eval_results[key] = val
+ ap_summary = ' '.join([
+ '{}:{:.3f}'.format(k, float(v))
+ for k, v in lvis_results.items() if k.startswith('AP')
+ ])
+ eval_results['{}_mAP_copypaste'.format(metric)] = ap_summary
+ lvis_eval.print_results()
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
+
+
+LVISDataset = LVISV05Dataset
+DATASETS.register_module(name='LVISDataset', module=LVISDataset)
+
+
+@DATASETS.register_module()
+class LVISV1Dataset(LVISDataset):
+
+ CLASSES = (
+ 'aerosol_can', 'air_conditioner', 'airplane', 'alarm_clock', 'alcohol',
+ 'alligator', 'almond', 'ambulance', 'amplifier', 'anklet', 'antenna',
+ 'apple', 'applesauce', 'apricot', 'apron', 'aquarium',
+ 'arctic_(type_of_shoe)', 'armband', 'armchair', 'armoire', 'armor',
+ 'artichoke', 'trash_can', 'ashtray', 'asparagus', 'atomizer',
+ 'avocado', 'award', 'awning', 'ax', 'baboon', 'baby_buggy',
+ 'basketball_backboard', 'backpack', 'handbag', 'suitcase', 'bagel',
+ 'bagpipe', 'baguet', 'bait', 'ball', 'ballet_skirt', 'balloon',
+ 'bamboo', 'banana', 'Band_Aid', 'bandage', 'bandanna', 'banjo',
+ 'banner', 'barbell', 'barge', 'barrel', 'barrette', 'barrow',
+ 'baseball_base', 'baseball', 'baseball_bat', 'baseball_cap',
+ 'baseball_glove', 'basket', 'basketball', 'bass_horn', 'bat_(animal)',
+ 'bath_mat', 'bath_towel', 'bathrobe', 'bathtub', 'batter_(food)',
+ 'battery', 'beachball', 'bead', 'bean_curd', 'beanbag', 'beanie',
+ 'bear', 'bed', 'bedpan', 'bedspread', 'cow', 'beef_(food)', 'beeper',
+ 'beer_bottle', 'beer_can', 'beetle', 'bell', 'bell_pepper', 'belt',
+ 'belt_buckle', 'bench', 'beret', 'bib', 'Bible', 'bicycle', 'visor',
+ 'billboard', 'binder', 'binoculars', 'bird', 'birdfeeder', 'birdbath',
+ 'birdcage', 'birdhouse', 'birthday_cake', 'birthday_card',
+ 'pirate_flag', 'black_sheep', 'blackberry', 'blackboard', 'blanket',
+ 'blazer', 'blender', 'blimp', 'blinker', 'blouse', 'blueberry',
+ 'gameboard', 'boat', 'bob', 'bobbin', 'bobby_pin', 'boiled_egg',
+ 'bolo_tie', 'deadbolt', 'bolt', 'bonnet', 'book', 'bookcase',
+ 'booklet', 'bookmark', 'boom_microphone', 'boot', 'bottle',
+ 'bottle_opener', 'bouquet', 'bow_(weapon)', 'bow_(decorative_ribbons)',
+ 'bow-tie', 'bowl', 'pipe_bowl', 'bowler_hat', 'bowling_ball', 'box',
+ 'boxing_glove', 'suspenders', 'bracelet', 'brass_plaque', 'brassiere',
+ 'bread-bin', 'bread', 'breechcloth', 'bridal_gown', 'briefcase',
+ 'broccoli', 'broach', 'broom', 'brownie', 'brussels_sprouts',
+ 'bubble_gum', 'bucket', 'horse_buggy', 'bull', 'bulldog', 'bulldozer',
+ 'bullet_train', 'bulletin_board', 'bulletproof_vest', 'bullhorn',
+ 'bun', 'bunk_bed', 'buoy', 'burrito', 'bus_(vehicle)', 'business_card',
+ 'butter', 'butterfly', 'button', 'cab_(taxi)', 'cabana', 'cabin_car',
+ 'cabinet', 'locker', 'cake', 'calculator', 'calendar', 'calf',
+ 'camcorder', 'camel', 'camera', 'camera_lens', 'camper_(vehicle)',
+ 'can', 'can_opener', 'candle', 'candle_holder', 'candy_bar',
+ 'candy_cane', 'walking_cane', 'canister', 'canoe', 'cantaloup',
+ 'canteen', 'cap_(headwear)', 'bottle_cap', 'cape', 'cappuccino',
+ 'car_(automobile)', 'railcar_(part_of_a_train)', 'elevator_car',
+ 'car_battery', 'identity_card', 'card', 'cardigan', 'cargo_ship',
+ 'carnation', 'horse_carriage', 'carrot', 'tote_bag', 'cart', 'carton',
+ 'cash_register', 'casserole', 'cassette', 'cast', 'cat', 'cauliflower',
+ 'cayenne_(spice)', 'CD_player', 'celery', 'cellular_telephone',
+ 'chain_mail', 'chair', 'chaise_longue', 'chalice', 'chandelier',
+ 'chap', 'checkbook', 'checkerboard', 'cherry', 'chessboard',
+ 'chicken_(animal)', 'chickpea', 'chili_(vegetable)', 'chime',
+ 'chinaware', 'crisp_(potato_chip)', 'poker_chip', 'chocolate_bar',
+ 'chocolate_cake', 'chocolate_milk', 'chocolate_mousse', 'choker',
+ 'chopping_board', 'chopstick', 'Christmas_tree', 'slide', 'cider',
+ 'cigar_box', 'cigarette', 'cigarette_case', 'cistern', 'clarinet',
+ 'clasp', 'cleansing_agent', 'cleat_(for_securing_rope)', 'clementine',
+ 'clip', 'clipboard', 'clippers_(for_plants)', 'cloak', 'clock',
+ 'clock_tower', 'clothes_hamper', 'clothespin', 'clutch_bag', 'coaster',
+ 'coat', 'coat_hanger', 'coatrack', 'cock', 'cockroach',
+ 'cocoa_(beverage)', 'coconut', 'coffee_maker', 'coffee_table',
+ 'coffeepot', 'coil', 'coin', 'colander', 'coleslaw',
+ 'coloring_material', 'combination_lock', 'pacifier', 'comic_book',
+ 'compass', 'computer_keyboard', 'condiment', 'cone', 'control',
+ 'convertible_(automobile)', 'sofa_bed', 'cooker', 'cookie',
+ 'cooking_utensil', 'cooler_(for_food)', 'cork_(bottle_plug)',
+ 'corkboard', 'corkscrew', 'edible_corn', 'cornbread', 'cornet',
+ 'cornice', 'cornmeal', 'corset', 'costume', 'cougar', 'coverall',
+ 'cowbell', 'cowboy_hat', 'crab_(animal)', 'crabmeat', 'cracker',
+ 'crape', 'crate', 'crayon', 'cream_pitcher', 'crescent_roll', 'crib',
+ 'crock_pot', 'crossbar', 'crouton', 'crow', 'crowbar', 'crown',
+ 'crucifix', 'cruise_ship', 'police_cruiser', 'crumb', 'crutch',
+ 'cub_(animal)', 'cube', 'cucumber', 'cufflink', 'cup', 'trophy_cup',
+ 'cupboard', 'cupcake', 'hair_curler', 'curling_iron', 'curtain',
+ 'cushion', 'cylinder', 'cymbal', 'dagger', 'dalmatian', 'dartboard',
+ 'date_(fruit)', 'deck_chair', 'deer', 'dental_floss', 'desk',
+ 'detergent', 'diaper', 'diary', 'die', 'dinghy', 'dining_table', 'tux',
+ 'dish', 'dish_antenna', 'dishrag', 'dishtowel', 'dishwasher',
+ 'dishwasher_detergent', 'dispenser', 'diving_board', 'Dixie_cup',
+ 'dog', 'dog_collar', 'doll', 'dollar', 'dollhouse', 'dolphin',
+ 'domestic_ass', 'doorknob', 'doormat', 'doughnut', 'dove', 'dragonfly',
+ 'drawer', 'underdrawers', 'dress', 'dress_hat', 'dress_suit',
+ 'dresser', 'drill', 'drone', 'dropper', 'drum_(musical_instrument)',
+ 'drumstick', 'duck', 'duckling', 'duct_tape', 'duffel_bag', 'dumbbell',
+ 'dumpster', 'dustpan', 'eagle', 'earphone', 'earplug', 'earring',
+ 'easel', 'eclair', 'eel', 'egg', 'egg_roll', 'egg_yolk', 'eggbeater',
+ 'eggplant', 'electric_chair', 'refrigerator', 'elephant', 'elk',
+ 'envelope', 'eraser', 'escargot', 'eyepatch', 'falcon', 'fan',
+ 'faucet', 'fedora', 'ferret', 'Ferris_wheel', 'ferry', 'fig_(fruit)',
+ 'fighter_jet', 'figurine', 'file_cabinet', 'file_(tool)', 'fire_alarm',
+ 'fire_engine', 'fire_extinguisher', 'fire_hose', 'fireplace',
+ 'fireplug', 'first-aid_kit', 'fish', 'fish_(food)', 'fishbowl',
+ 'fishing_rod', 'flag', 'flagpole', 'flamingo', 'flannel', 'flap',
+ 'flash', 'flashlight', 'fleece', 'flip-flop_(sandal)',
+ 'flipper_(footwear)', 'flower_arrangement', 'flute_glass', 'foal',
+ 'folding_chair', 'food_processor', 'football_(American)',
+ 'football_helmet', 'footstool', 'fork', 'forklift', 'freight_car',
+ 'French_toast', 'freshener', 'frisbee', 'frog', 'fruit_juice',
+ 'frying_pan', 'fudge', 'funnel', 'futon', 'gag', 'garbage',
+ 'garbage_truck', 'garden_hose', 'gargle', 'gargoyle', 'garlic',
+ 'gasmask', 'gazelle', 'gelatin', 'gemstone', 'generator',
+ 'giant_panda', 'gift_wrap', 'ginger', 'giraffe', 'cincture',
+ 'glass_(drink_container)', 'globe', 'glove', 'goat', 'goggles',
+ 'goldfish', 'golf_club', 'golfcart', 'gondola_(boat)', 'goose',
+ 'gorilla', 'gourd', 'grape', 'grater', 'gravestone', 'gravy_boat',
+ 'green_bean', 'green_onion', 'griddle', 'grill', 'grits', 'grizzly',
+ 'grocery_bag', 'guitar', 'gull', 'gun', 'hairbrush', 'hairnet',
+ 'hairpin', 'halter_top', 'ham', 'hamburger', 'hammer', 'hammock',
+ 'hamper', 'hamster', 'hair_dryer', 'hand_glass', 'hand_towel',
+ 'handcart', 'handcuff', 'handkerchief', 'handle', 'handsaw',
+ 'hardback_book', 'harmonium', 'hat', 'hatbox', 'veil', 'headband',
+ 'headboard', 'headlight', 'headscarf', 'headset',
+ 'headstall_(for_horses)', 'heart', 'heater', 'helicopter', 'helmet',
+ 'heron', 'highchair', 'hinge', 'hippopotamus', 'hockey_stick', 'hog',
+ 'home_plate_(baseball)', 'honey', 'fume_hood', 'hook', 'hookah',
+ 'hornet', 'horse', 'hose', 'hot-air_balloon', 'hotplate', 'hot_sauce',
+ 'hourglass', 'houseboat', 'hummingbird', 'hummus', 'polar_bear',
+ 'icecream', 'popsicle', 'ice_maker', 'ice_pack', 'ice_skate',
+ 'igniter', 'inhaler', 'iPod', 'iron_(for_clothing)', 'ironing_board',
+ 'jacket', 'jam', 'jar', 'jean', 'jeep', 'jelly_bean', 'jersey',
+ 'jet_plane', 'jewel', 'jewelry', 'joystick', 'jumpsuit', 'kayak',
+ 'keg', 'kennel', 'kettle', 'key', 'keycard', 'kilt', 'kimono',
+ 'kitchen_sink', 'kitchen_table', 'kite', 'kitten', 'kiwi_fruit',
+ 'knee_pad', 'knife', 'knitting_needle', 'knob', 'knocker_(on_a_door)',
+ 'koala', 'lab_coat', 'ladder', 'ladle', 'ladybug', 'lamb_(animal)',
+ 'lamb-chop', 'lamp', 'lamppost', 'lampshade', 'lantern', 'lanyard',
+ 'laptop_computer', 'lasagna', 'latch', 'lawn_mower', 'leather',
+ 'legging_(clothing)', 'Lego', 'legume', 'lemon', 'lemonade', 'lettuce',
+ 'license_plate', 'life_buoy', 'life_jacket', 'lightbulb',
+ 'lightning_rod', 'lime', 'limousine', 'lion', 'lip_balm', 'liquor',
+ 'lizard', 'log', 'lollipop', 'speaker_(stero_equipment)', 'loveseat',
+ 'machine_gun', 'magazine', 'magnet', 'mail_slot', 'mailbox_(at_home)',
+ 'mallard', 'mallet', 'mammoth', 'manatee', 'mandarin_orange', 'manger',
+ 'manhole', 'map', 'marker', 'martini', 'mascot', 'mashed_potato',
+ 'masher', 'mask', 'mast', 'mat_(gym_equipment)', 'matchbox',
+ 'mattress', 'measuring_cup', 'measuring_stick', 'meatball', 'medicine',
+ 'melon', 'microphone', 'microscope', 'microwave_oven', 'milestone',
+ 'milk', 'milk_can', 'milkshake', 'minivan', 'mint_candy', 'mirror',
+ 'mitten', 'mixer_(kitchen_tool)', 'money',
+ 'monitor_(computer_equipment) computer_monitor', 'monkey', 'motor',
+ 'motor_scooter', 'motor_vehicle', 'motorcycle', 'mound_(baseball)',
+ 'mouse_(computer_equipment)', 'mousepad', 'muffin', 'mug', 'mushroom',
+ 'music_stool', 'musical_instrument', 'nailfile', 'napkin',
+ 'neckerchief', 'necklace', 'necktie', 'needle', 'nest', 'newspaper',
+ 'newsstand', 'nightshirt', 'nosebag_(for_animals)',
+ 'noseband_(for_animals)', 'notebook', 'notepad', 'nut', 'nutcracker',
+ 'oar', 'octopus_(food)', 'octopus_(animal)', 'oil_lamp', 'olive_oil',
+ 'omelet', 'onion', 'orange_(fruit)', 'orange_juice', 'ostrich',
+ 'ottoman', 'oven', 'overalls_(clothing)', 'owl', 'packet', 'inkpad',
+ 'pad', 'paddle', 'padlock', 'paintbrush', 'painting', 'pajamas',
+ 'palette', 'pan_(for_cooking)', 'pan_(metal_container)', 'pancake',
+ 'pantyhose', 'papaya', 'paper_plate', 'paper_towel', 'paperback_book',
+ 'paperweight', 'parachute', 'parakeet', 'parasail_(sports)', 'parasol',
+ 'parchment', 'parka', 'parking_meter', 'parrot',
+ 'passenger_car_(part_of_a_train)', 'passenger_ship', 'passport',
+ 'pastry', 'patty_(food)', 'pea_(food)', 'peach', 'peanut_butter',
+ 'pear', 'peeler_(tool_for_fruit_and_vegetables)', 'wooden_leg',
+ 'pegboard', 'pelican', 'pen', 'pencil', 'pencil_box',
+ 'pencil_sharpener', 'pendulum', 'penguin', 'pennant', 'penny_(coin)',
+ 'pepper', 'pepper_mill', 'perfume', 'persimmon', 'person', 'pet',
+ 'pew_(church_bench)', 'phonebook', 'phonograph_record', 'piano',
+ 'pickle', 'pickup_truck', 'pie', 'pigeon', 'piggy_bank', 'pillow',
+ 'pin_(non_jewelry)', 'pineapple', 'pinecone', 'ping-pong_ball',
+ 'pinwheel', 'tobacco_pipe', 'pipe', 'pistol', 'pita_(bread)',
+ 'pitcher_(vessel_for_liquid)', 'pitchfork', 'pizza', 'place_mat',
+ 'plate', 'platter', 'playpen', 'pliers', 'plow_(farm_equipment)',
+ 'plume', 'pocket_watch', 'pocketknife', 'poker_(fire_stirring_tool)',
+ 'pole', 'polo_shirt', 'poncho', 'pony', 'pool_table', 'pop_(soda)',
+ 'postbox_(public)', 'postcard', 'poster', 'pot', 'flowerpot', 'potato',
+ 'potholder', 'pottery', 'pouch', 'power_shovel', 'prawn', 'pretzel',
+ 'printer', 'projectile_(weapon)', 'projector', 'propeller', 'prune',
+ 'pudding', 'puffer_(fish)', 'puffin', 'pug-dog', 'pumpkin', 'puncher',
+ 'puppet', 'puppy', 'quesadilla', 'quiche', 'quilt', 'rabbit',
+ 'race_car', 'racket', 'radar', 'radiator', 'radio_receiver', 'radish',
+ 'raft', 'rag_doll', 'raincoat', 'ram_(animal)', 'raspberry', 'rat',
+ 'razorblade', 'reamer_(juicer)', 'rearview_mirror', 'receipt',
+ 'recliner', 'record_player', 'reflector', 'remote_control',
+ 'rhinoceros', 'rib_(food)', 'rifle', 'ring', 'river_boat', 'road_map',
+ 'robe', 'rocking_chair', 'rodent', 'roller_skate', 'Rollerblade',
+ 'rolling_pin', 'root_beer', 'router_(computer_equipment)',
+ 'rubber_band', 'runner_(carpet)', 'plastic_bag',
+ 'saddle_(on_an_animal)', 'saddle_blanket', 'saddlebag', 'safety_pin',
+ 'sail', 'salad', 'salad_plate', 'salami', 'salmon_(fish)',
+ 'salmon_(food)', 'salsa', 'saltshaker', 'sandal_(type_of_shoe)',
+ 'sandwich', 'satchel', 'saucepan', 'saucer', 'sausage', 'sawhorse',
+ 'saxophone', 'scale_(measuring_instrument)', 'scarecrow', 'scarf',
+ 'school_bus', 'scissors', 'scoreboard', 'scraper', 'screwdriver',
+ 'scrubbing_brush', 'sculpture', 'seabird', 'seahorse', 'seaplane',
+ 'seashell', 'sewing_machine', 'shaker', 'shampoo', 'shark',
+ 'sharpener', 'Sharpie', 'shaver_(electric)', 'shaving_cream', 'shawl',
+ 'shears', 'sheep', 'shepherd_dog', 'sherbert', 'shield', 'shirt',
+ 'shoe', 'shopping_bag', 'shopping_cart', 'short_pants', 'shot_glass',
+ 'shoulder_bag', 'shovel', 'shower_head', 'shower_cap',
+ 'shower_curtain', 'shredder_(for_paper)', 'signboard', 'silo', 'sink',
+ 'skateboard', 'skewer', 'ski', 'ski_boot', 'ski_parka', 'ski_pole',
+ 'skirt', 'skullcap', 'sled', 'sleeping_bag', 'sling_(bandage)',
+ 'slipper_(footwear)', 'smoothie', 'snake', 'snowboard', 'snowman',
+ 'snowmobile', 'soap', 'soccer_ball', 'sock', 'sofa', 'softball',
+ 'solar_array', 'sombrero', 'soup', 'soup_bowl', 'soupspoon',
+ 'sour_cream', 'soya_milk', 'space_shuttle', 'sparkler_(fireworks)',
+ 'spatula', 'spear', 'spectacles', 'spice_rack', 'spider', 'crawfish',
+ 'sponge', 'spoon', 'sportswear', 'spotlight', 'squid_(food)',
+ 'squirrel', 'stagecoach', 'stapler_(stapling_machine)', 'starfish',
+ 'statue_(sculpture)', 'steak_(food)', 'steak_knife', 'steering_wheel',
+ 'stepladder', 'step_stool', 'stereo_(sound_system)', 'stew', 'stirrer',
+ 'stirrup', 'stool', 'stop_sign', 'brake_light', 'stove', 'strainer',
+ 'strap', 'straw_(for_drinking)', 'strawberry', 'street_sign',
+ 'streetlight', 'string_cheese', 'stylus', 'subwoofer', 'sugar_bowl',
+ 'sugarcane_(plant)', 'suit_(clothing)', 'sunflower', 'sunglasses',
+ 'sunhat', 'surfboard', 'sushi', 'mop', 'sweat_pants', 'sweatband',
+ 'sweater', 'sweatshirt', 'sweet_potato', 'swimsuit', 'sword',
+ 'syringe', 'Tabasco_sauce', 'table-tennis_table', 'table',
+ 'table_lamp', 'tablecloth', 'tachometer', 'taco', 'tag', 'taillight',
+ 'tambourine', 'army_tank', 'tank_(storage_vessel)',
+ 'tank_top_(clothing)', 'tape_(sticky_cloth_or_paper)', 'tape_measure',
+ 'tapestry', 'tarp', 'tartan', 'tassel', 'tea_bag', 'teacup',
+ 'teakettle', 'teapot', 'teddy_bear', 'telephone', 'telephone_booth',
+ 'telephone_pole', 'telephoto_lens', 'television_camera',
+ 'television_set', 'tennis_ball', 'tennis_racket', 'tequila',
+ 'thermometer', 'thermos_bottle', 'thermostat', 'thimble', 'thread',
+ 'thumbtack', 'tiara', 'tiger', 'tights_(clothing)', 'timer', 'tinfoil',
+ 'tinsel', 'tissue_paper', 'toast_(food)', 'toaster', 'toaster_oven',
+ 'toilet', 'toilet_tissue', 'tomato', 'tongs', 'toolbox', 'toothbrush',
+ 'toothpaste', 'toothpick', 'cover', 'tortilla', 'tow_truck', 'towel',
+ 'towel_rack', 'toy', 'tractor_(farm_equipment)', 'traffic_light',
+ 'dirt_bike', 'trailer_truck', 'train_(railroad_vehicle)', 'trampoline',
+ 'tray', 'trench_coat', 'triangle_(musical_instrument)', 'tricycle',
+ 'tripod', 'trousers', 'truck', 'truffle_(chocolate)', 'trunk', 'vat',
+ 'turban', 'turkey_(food)', 'turnip', 'turtle', 'turtleneck_(clothing)',
+ 'typewriter', 'umbrella', 'underwear', 'unicycle', 'urinal', 'urn',
+ 'vacuum_cleaner', 'vase', 'vending_machine', 'vent', 'vest',
+ 'videotape', 'vinegar', 'violin', 'vodka', 'volleyball', 'vulture',
+ 'waffle', 'waffle_iron', 'wagon', 'wagon_wheel', 'walking_stick',
+ 'wall_clock', 'wall_socket', 'wallet', 'walrus', 'wardrobe',
+ 'washbasin', 'automatic_washer', 'watch', 'water_bottle',
+ 'water_cooler', 'water_faucet', 'water_heater', 'water_jug',
+ 'water_gun', 'water_scooter', 'water_ski', 'water_tower',
+ 'watering_can', 'watermelon', 'weathervane', 'webcam', 'wedding_cake',
+ 'wedding_ring', 'wet_suit', 'wheel', 'wheelchair', 'whipped_cream',
+ 'whistle', 'wig', 'wind_chime', 'windmill', 'window_box_(for_plants)',
+ 'windshield_wiper', 'windsock', 'wine_bottle', 'wine_bucket',
+ 'wineglass', 'blinder_(for_horses)', 'wok', 'wolf', 'wooden_spoon',
+ 'wreath', 'wrench', 'wristband', 'wristlet', 'yacht', 'yogurt',
+ 'yoke_(animal_equipment)', 'zebra', 'zucchini')
+
+ def load_annotations(self, ann_file):
+ try:
+ import lvis
+ assert lvis.__version__ >= '10.5.3'
+ from lvis import LVIS
+ except AssertionError:
+ raise AssertionError('Incompatible version of lvis is installed. '
+ 'Run pip uninstall lvis first. Then run pip '
+ 'install mmlvis to install open-mmlab forked '
+ 'lvis. ')
+ except ImportError:
+ raise ImportError('Package lvis is not installed. Please run pip '
+ 'install mmlvis to install open-mmlab forked '
+ 'lvis.')
+ self.coco = LVIS(ann_file)
+ self.cat_ids = self.coco.get_cat_ids()
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ # coco_url is used in LVISv1 instead of file_name
+ # e.g. http://images.cocodataset.org/train2017/000000391895.jpg
+            # the train/val split is specified in the url
+ info['filename'] = info['coco_url'].replace(
+ 'http://images.cocodataset.org/', '')
+ data_infos.append(info)
+ return data_infos
diff --git a/mmdet/datasets/pipelines/__init__.py b/mmdet/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6f424debd1623e7511dd77da464a6639d816745
--- /dev/null
+++ b/mmdet/datasets/pipelines/__init__.py
@@ -0,0 +1,25 @@
+from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform,
+ ContrastTransform, EqualizeTransform, Rotate, Shear,
+ Translate)
+from .compose import Compose
+from .formating import (Collect, DefaultFormatBundle, ImageToTensor,
+ ToDataContainer, ToTensor, Transpose, to_tensor)
+from .instaboost import InstaBoost
+from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam,
+ LoadMultiChannelImageFromFiles, LoadProposals)
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize,
+ Pad, PhotoMetricDistortion, RandomCenterCropPad,
+ RandomCrop, RandomFlip, Resize, SegRescale)
+
+__all__ = [
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+ 'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations',
+ 'LoadImageFromFile', 'LoadImageFromWebcam',
+ 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'MultiScaleFlipAug',
+ 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale',
+ 'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu',
+ 'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear',
+ 'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform',
+ 'ContrastTransform', 'Translate'
+]
diff --git a/mmdet/datasets/pipelines/auto_augment.py b/mmdet/datasets/pipelines/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19adaec18a96cac4dbe1d8c2c9193e9901be1fb
--- /dev/null
+++ b/mmdet/datasets/pipelines/auto_augment.py
@@ -0,0 +1,890 @@
+import copy
+
+import cv2
+import mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+_MAX_LEVEL = 10
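+# Augmentation "levels" are expressed on a 0.._MAX_LEVEL scale and converted
+# into concrete magnitudes by the helper functions below.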
+
+
+def level_to_value(level, max_value):
+ """Map from level to values based on max_value."""
+ return (level / _MAX_LEVEL) * max_value
+
+
+def enhance_level_to_value(level, a=1.8, b=0.1):
+ """Map from level to values."""
+ return (level / _MAX_LEVEL) * a + b
+
+
+def random_negative(value, random_negative_prob):
+ """Randomly negate value based on random_negative_prob."""
+ return -value if np.random.rand() < random_negative_prob else value
+
+
+def bbox2fields():
+ """The key correspondence from bboxes to labels, masks and
+ segmentations."""
+ bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+ bbox2seg = {
+ 'gt_bboxes': 'gt_semantic_seg',
+ }
+ return bbox2label, bbox2mask, bbox2seg
+
+
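These three helpers drive every transform below: ``level_to_value`` scales a level in ``[0, _MAX_LEVEL]`` to a concrete magnitude, ``enhance_level_to_value`` maps it to an enhancement factor, and ``random_negative`` flips the sign with the given probability. A small worked example, following the formulas directly:

    level_to_value(5, 0.3)       # 5 / 10 * 0.3 = 0.15 (Shear's default max magnitude)
    enhance_level_to_value(5)    # 5 / 10 * 1.8 + 0.1 = 1.0
    random_negative(0.15, 0.5)   # 0.15 or -0.15, each with probability 0.5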
+@PIPELINES.register_module()
+class AutoAugment(object):
+ """Auto augmentation.
+
+ This data augmentation is proposed in `Learning Data Augmentation
+    Strategies for Object Detection <https://arxiv.org/abs/1906.11172>`_.
+
+    TODO: Implement 'Sharpness' transform ('Shear' and 'Rotate' are implemented below).
+
+ Args:
+ policies (list[list[dict]]): The policies of auto augmentation. Each
+ policy in ``policies`` is a specific augmentation policy, and is
+ composed by several augmentations (dict). When AutoAugment is
+ called, a random policy in ``policies`` will be selected to
+ augment images.
+
+ Examples:
+ >>> replace = (104, 116, 124)
+ >>> policies = [
+ >>> [
+ >>> dict(type='Sharpness', prob=0.0, level=8),
+ >>> dict(
+ >>> type='Shear',
+ >>> prob=0.4,
+ >>> level=0,
+ >>> replace=replace,
+ >>> axis='x')
+ >>> ],
+ >>> [
+ >>> dict(
+ >>> type='Rotate',
+ >>> prob=0.6,
+ >>> level=10,
+ >>> replace=replace),
+ >>> dict(type='Color', prob=1.0, level=6)
+ >>> ]
+ >>> ]
+ >>> augmentation = AutoAugment(policies)
+        >>> img = np.ones((100, 100, 3))
+        >>> gt_bboxes = np.ones((10, 4))
+ >>> results = dict(img=img, gt_bboxes=gt_bboxes)
+ >>> results = augmentation(results)
+ """
+
+ def __init__(self, policies):
+ assert isinstance(policies, list) and len(policies) > 0, \
+ 'Policies must be a non-empty list.'
+ for policy in policies:
+ assert isinstance(policy, list) and len(policy) > 0, \
+ 'Each policy in policies must be a non-empty list.'
+ for augment in policy:
+ assert isinstance(augment, dict) and 'type' in augment, \
+ 'Each specific augmentation must be a dict with key' \
+ ' "type".'
+
+ self.policies = copy.deepcopy(policies)
+ self.transforms = [Compose(policy) for policy in self.policies]
+
+ def __call__(self, results):
+ transform = np.random.choice(self.transforms)
+ return transform(results)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(policies={self.policies})'
+
+
+@PIPELINES.register_module()
+class Shear(object):
+ """Apply Shear Transformation to image (and its corresponding bbox, mask,
+ segmentation).
+
+ Args:
+ level (int | float): The level should be in range [0,_MAX_LEVEL].
+ img_fill_val (int | float | tuple): The filled values for image border.
+ If float, the same fill value will be used for all the three
+            channels of image. If tuple, it should have 3 elements.
+ seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+ of the corresponding config. Default 255.
+ prob (float): The probability for performing Shear and should be in
+ range [0, 1].
+ direction (str): The direction for shear, either "horizontal"
+ or "vertical".
+ max_shear_magnitude (float): The maximum magnitude for Shear
+ transformation.
+ random_negative_prob (float): The probability that turns the
+ offset negative. Should be in range [0,1]
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+ """
+
+ def __init__(self,
+ level,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ prob=0.5,
+ direction='horizontal',
+ max_shear_magnitude=0.3,
+ random_negative_prob=0.5,
+ interpolation='bilinear'):
+ assert isinstance(level, (int, float)), 'The level must be type ' \
+ f'int or float, got {type(level)}.'
+ assert 0 <= level <= _MAX_LEVEL, 'The level should be in range ' \
+ f'[0,{_MAX_LEVEL}], got {level}.'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, 'img_fill_val as tuple must ' \
+ f'have 3 elements. got {len(img_fill_val)}.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError(
+ 'img_fill_val must be float or tuple with 3 elements.')
+        assert np.all([0 <= val <= 255 for val in img_fill_val]), 'all ' \
+            'elements of img_fill_val should be in range [0,255]. ' \
+            f'got {img_fill_val}.'
+ assert 0 <= prob <= 1.0, 'The probability of shear should be in ' \
+ f'range [0,1]. got {prob}.'
+        assert direction in ('horizontal', 'vertical'), 'direction must ' \
+            f'be either "horizontal" or "vertical". got {direction}.'
+ assert isinstance(max_shear_magnitude, float), 'max_shear_magnitude ' \
+ f'should be type float. got {type(max_shear_magnitude)}.'
+        assert 0. <= max_shear_magnitude <= 1., 'max_shear_magnitude ' \
+            'should be in range [0,1]. ' \
+            f'got {max_shear_magnitude}.'
+ self.level = level
+ self.magnitude = level_to_value(level, max_shear_magnitude)
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+ self.direction = direction
+ self.max_shear_magnitude = max_shear_magnitude
+ self.random_negative_prob = random_negative_prob
+ self.interpolation = interpolation
+
+ def _shear_img(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ interpolation='bilinear'):
+ """Shear the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The direction for shear, either "horizontal"
+ or "vertical".
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ img_sheared = mmcv.imshear(
+ img,
+ magnitude,
+ direction,
+ border_value=self.img_fill_val,
+ interpolation=interpolation)
+ results[key] = img_sheared.astype(img.dtype)
+
+ def _shear_bboxes(self, results, magnitude):
+ """Shear the bboxes."""
+ h, w, c = results['img_shape']
+ if self.direction == 'horizontal':
+ shear_matrix = np.stack([[1, magnitude],
+ [0, 1]]).astype(np.float32) # [2, 2]
+ else:
+ shear_matrix = np.stack([[1, 0], [magnitude,
+ 1]]).astype(np.float32)
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ coordinates = np.stack([[min_x, min_y], [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y]]) # [4, 2, nb_box, 1]
+ coordinates = coordinates[..., 0].transpose(
+ (2, 1, 0)).astype(np.float32) # [nb_box, 2, 4]
+ new_coords = np.matmul(shear_matrix[None, :, :],
+ coordinates) # [nb_box, 2, 4]
+ min_x = np.min(new_coords[:, 0, :], axis=-1)
+ min_y = np.min(new_coords[:, 1, :], axis=-1)
+ max_x = np.max(new_coords[:, 0, :], axis=-1)
+ max_y = np.max(new_coords[:, 1, :], axis=-1)
+ min_x = np.clip(min_x, a_min=0, a_max=w)
+ min_y = np.clip(min_y, a_min=0, a_max=h)
+ max_x = np.clip(max_x, a_min=min_x, a_max=w)
+ max_y = np.clip(max_y, a_min=min_y, a_max=h)
+ results[key] = np.stack([min_x, min_y, max_x, max_y],
+ axis=-1).astype(results[key].dtype)
+
+ def _shear_masks(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Shear the masks."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.shear((h, w),
+ magnitude,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation)
+
+ def _shear_seg(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ fill_val=255,
+ interpolation='bilinear'):
+ """Shear the segmentation maps."""
+ for key in results.get('seg_fields', []):
+ seg = results[key]
+ results[key] = mmcv.imshear(
+ seg,
+ magnitude,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_bbox_size=0):
+ """Filter bboxes and corresponding masks too small after shear
+ augmentation."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+
+ def __call__(self, results):
+ """Call function to shear images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Sheared results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ magnitude = random_negative(self.magnitude, self.random_negative_prob)
+ self._shear_img(results, magnitude, self.direction, self.interpolation)
+ self._shear_bboxes(results, magnitude)
+ # fill_val set to 0 for background of mask.
+ self._shear_masks(
+ results,
+ magnitude,
+ self.direction,
+ fill_val=0,
+ interpolation=self.interpolation)
+ self._shear_seg(
+ results,
+ magnitude,
+ self.direction,
+ fill_val=self.seg_ignore_label,
+ interpolation=self.interpolation)
+ self._filter_invalid(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'img_fill_val={self.img_fill_val}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'direction={self.direction}, '
+ repr_str += f'max_shear_magnitude={self.max_shear_magnitude}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}, '
+ repr_str += f'interpolation={self.interpolation})'
+ return repr_str
+
+
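In a pipeline config, ``Shear`` is driven by ``level`` alone; the magnitude actually applied is ``level / _MAX_LEVEL * max_shear_magnitude`` (optionally negated). A sketch with illustrative values:

    shear = dict(
        type='Shear',
        level=4,                  # magnitude = 4 / 10 * 0.3 = 0.12
        prob=0.5,
        direction='horizontal',
        img_fill_val=128)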
+@PIPELINES.register_module()
+class Rotate(object):
+ """Apply Rotate Transformation to image (and its corresponding bbox, mask,
+ segmentation).
+
+ Args:
+        level (int | float): The level should be in range [0,_MAX_LEVEL].
+ scale (int | float): Isotropic scale factor. Same in
+ ``mmcv.imrotate``.
+ center (int | float | tuple[float]): Center point (w, h) of the
+ rotation in the source image. If None, the center of the
+ image will be used. Same in ``mmcv.imrotate``.
+ img_fill_val (int | float | tuple): The fill value for image border.
+ If float, the same value will be used for all the three
+            channels of image. If tuple, it should have 3 elements (i.e.
+            equal to the number of channels of the image).
+ seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+ of the corresponding config. Default 255.
+        prob (float): The probability of performing the transformation,
+            which should be in range [0, 1].
+ max_rotate_angle (int | float): The maximum angles for rotate
+ transformation.
+ random_negative_prob (float): The probability that turns the
+ offset negative.
+ """
+
+ def __init__(self,
+ level,
+ scale=1,
+ center=None,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ prob=0.5,
+ max_rotate_angle=30,
+ random_negative_prob=0.5):
+ assert isinstance(level, (int, float)), \
+ f'The level must be type int or float. got {type(level)}.'
+ assert 0 <= level <= _MAX_LEVEL, \
+            f'The level should be in range [0,{_MAX_LEVEL}]. got {level}.'
+ assert isinstance(scale, (int, float)), \
+ f'The scale must be type int or float. got type {type(scale)}.'
+ if isinstance(center, (int, float)):
+ center = (center, center)
+ elif isinstance(center, tuple):
+ assert len(center) == 2, 'center with type tuple must have '\
+ f'2 elements. got {len(center)} elements.'
+ else:
+ assert center is None, 'center must be None or type int, '\
+ f'float or tuple, got type {type(center)}.'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
+ f'have 3 elements. got {len(img_fill_val)}.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError(
+ 'img_fill_val must be float or tuple with 3 elements.')
+ assert np.all([0 <= val <= 255 for val in img_fill_val]), \
+            'all elements of img_fill_val should be in range [0,255]. '\
+ f'got {img_fill_val}.'
+        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
+            f'got {prob}.'
+ assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
+ f'should be type int or float. got type {type(max_rotate_angle)}.'
+ self.level = level
+ self.scale = scale
+ # Rotation angle in degrees. Positive values mean
+ # clockwise rotation.
+ self.angle = level_to_value(level, max_rotate_angle)
+ self.center = center
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+ self.max_rotate_angle = max_rotate_angle
+ self.random_negative_prob = random_negative_prob
+
+ def _rotate_img(self, results, angle, center=None, scale=1.0):
+ """Rotate the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ angle (float): Rotation angle in degrees, positive values
+ mean clockwise rotation. Same in ``mmcv.imrotate``.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation. Same in ``mmcv.imrotate``.
+ scale (int | float): Isotropic scale factor. Same in
+ ``mmcv.imrotate``.
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key].copy()
+ img_rotated = mmcv.imrotate(
+ img, angle, center, scale, border_value=self.img_fill_val)
+ results[key] = img_rotated.astype(img.dtype)
+
+ def _rotate_bboxes(self, results, rotate_matrix):
+ """Rotate the bboxes."""
+ h, w, c = results['img_shape']
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ coordinates = np.stack([[min_x, min_y], [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y]]) # [4, 2, nb_bbox, 1]
+ # pad 1 to convert from format [x, y] to homogeneous
+ # coordinates format [x, y, 1]
+ coordinates = np.concatenate(
+ (coordinates,
+ np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
+ axis=1) # [4, 3, nb_bbox, 1]
+ coordinates = coordinates.transpose(
+ (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1]
+ rotated_coords = np.matmul(rotate_matrix,
+ coordinates) # [nb_bbox, 4, 2, 1]
+ rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2]
+ min_x, min_y = np.min(
+ rotated_coords[:, :, 0], axis=1), np.min(
+ rotated_coords[:, :, 1], axis=1)
+ max_x, max_y = np.max(
+ rotated_coords[:, :, 0], axis=1), np.max(
+ rotated_coords[:, :, 1], axis=1)
+ min_x, min_y = np.clip(
+ min_x, a_min=0, a_max=w), np.clip(
+ min_y, a_min=0, a_max=h)
+ max_x, max_y = np.clip(
+ max_x, a_min=min_x, a_max=w), np.clip(
+ max_y, a_min=min_y, a_max=h)
+ results[key] = np.stack([min_x, min_y, max_x, max_y],
+ axis=-1).astype(results[key].dtype)
+
+ def _rotate_masks(self,
+ results,
+ angle,
+ center=None,
+ scale=1.0,
+ fill_val=0):
+ """Rotate the masks."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.rotate((h, w), angle, center, scale, fill_val)
+
+ def _rotate_seg(self,
+ results,
+ angle,
+ center=None,
+ scale=1.0,
+ fill_val=255):
+ """Rotate the segmentation map."""
+ for key in results.get('seg_fields', []):
+ seg = results[key].copy()
+ results[key] = mmcv.imrotate(
+ seg, angle, center, scale,
+ border_value=fill_val).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_bbox_size=0):
+ """Filter bboxes and corresponding masks too small after rotate
+ augmentation."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+
+ def __call__(self, results):
+ """Call function to rotate images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Rotated results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ h, w = results['img'].shape[:2]
+ center = self.center
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ angle = random_negative(self.angle, self.random_negative_prob)
+ self._rotate_img(results, angle, center, self.scale)
+ rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
+ self._rotate_bboxes(results, rotate_matrix)
+ self._rotate_masks(results, angle, center, self.scale, fill_val=0)
+ self._rotate_seg(
+ results, angle, center, self.scale, fill_val=self.seg_ignore_label)
+ self._filter_invalid(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'center={self.center}, '
+ repr_str += f'img_fill_val={self.img_fill_val}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob})'
+ return repr_str
+
+
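``_rotate_bboxes`` above maps the four box corners through the 2x3 affine matrix returned by ``cv2.getRotationMatrix2D``, padding each corner to homogeneous ``[x, y, 1]`` form first; the ``-angle`` mirrors what ``__call__`` does, since OpenCV treats positive angles as counter-clockwise while ``self.angle`` is defined as clockwise. A standalone sketch of the same arithmetic (not code from this file):

    import cv2
    import numpy as np

    center, angle, scale = (50.0, 50.0), 30, 1.0
    matrix = cv2.getRotationMatrix2D(center, -angle, scale)           # (2, 3)
    corners = np.array([[10, 10, 1], [90, 10, 1],
                        [10, 90, 1], [90, 90, 1]], dtype=np.float32)  # (4, 3)
    rotated = corners @ matrix.T                                      # (4, 2)
    new_box = np.concatenate([rotated.min(0), rotated.max(0)])        # x1, y1, x2, y2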
+@PIPELINES.register_module()
+class Translate(object):
+ """Translate the images, bboxes, masks and segmentation maps horizontally
+ or vertically.
+
+ Args:
+ level (int | float): The level for Translate and should be in
+ range [0,_MAX_LEVEL].
+ prob (float): The probability for performing translation and
+ should be in range [0, 1].
+ img_fill_val (int | float | tuple): The filled value for image
+ border. If float, the same fill value will be used for all
+            the three channels of image. If tuple, it should have 3
+            elements (i.e. equal to the number of channels of the image).
+ seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+ of the corresponding config. Default 255.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+        max_translate_offset (int | float): The maximum pixel offset for
+ Translate.
+ random_negative_prob (float): The probability that turns the
+ offset negative.
+        min_size (int | float): The minimum pixel size for filtering
+ invalid bboxes after the translation.
+ """
+
+ def __init__(self,
+ level,
+ prob=0.5,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ direction='horizontal',
+ max_translate_offset=250.,
+ random_negative_prob=0.5,
+ min_size=0):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level used for calculating Translate\'s offset should be ' \
+ 'in range [0,_MAX_LEVEL]'
+ assert 0 <= prob <= 1.0, \
+ 'The probability of translation should be in range [0, 1].'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, \
+ 'img_fill_val as tuple must have 3 elements.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError('img_fill_val must be type float or tuple.')
+ assert np.all([0 <= val <= 255 for val in img_fill_val]), \
+            'all elements of img_fill_val should be in range [0,255].'
+ assert direction in ('horizontal', 'vertical'), \
+ 'direction should be "horizontal" or "vertical".'
+ assert isinstance(max_translate_offset, (int, float)), \
+ 'The max_translate_offset must be type int or float.'
+ # the offset used for translation
+ self.offset = int(level_to_value(level, max_translate_offset))
+ self.level = level
+ self.prob = prob
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.direction = direction
+ self.max_translate_offset = max_translate_offset
+ self.random_negative_prob = random_negative_prob
+ self.min_size = min_size
+
+ def _translate_img(self, results, offset, direction='horizontal'):
+ """Translate the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key].copy()
+ results[key] = mmcv.imtranslate(
+ img, offset, direction, self.img_fill_val).astype(img.dtype)
+
+ def _translate_bboxes(self, results, offset):
+ """Shift bboxes horizontally or vertically, according to offset."""
+ h, w, c = results['img_shape']
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ if self.direction == 'horizontal':
+ min_x = np.maximum(0, min_x + offset)
+ max_x = np.minimum(w, max_x + offset)
+ elif self.direction == 'vertical':
+ min_y = np.maximum(0, min_y + offset)
+ max_y = np.minimum(h, max_y + offset)
+
+ # the boxes translated outside of image will be filtered along with
+ # the corresponding masks, by invoking ``_filter_invalid``.
+ results[key] = np.concatenate([min_x, min_y, max_x, max_y],
+ axis=-1)
+
+ def _translate_masks(self,
+ results,
+ offset,
+ direction='horizontal',
+ fill_val=0):
+ """Translate masks horizontally or vertically."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.translate((h, w), offset, direction, fill_val)
+
+ def _translate_seg(self,
+ results,
+ offset,
+ direction='horizontal',
+ fill_val=255):
+ """Translate segmentation maps horizontally or vertically."""
+ for key in results.get('seg_fields', []):
+ seg = results[key].copy()
+ results[key] = mmcv.imtranslate(seg, offset, direction,
+ fill_val).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_size=0):
+ """Filter bboxes and masks too small or translated out of image."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_size) & (bbox_h > min_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+ return results
+
+ def __call__(self, results):
+ """Call function to translate images, bounding boxes, masks and
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Translated results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ offset = random_negative(self.offset, self.random_negative_prob)
+ self._translate_img(results, offset, self.direction)
+ self._translate_bboxes(results, offset)
+        # fill_val defaults to 0 for BitmapMasks and None for PolygonMasks.
+ self._translate_masks(results, offset, self.direction)
+ # fill_val set to ``seg_ignore_label`` for the ignored value
+ # of segmentation map.
+ self._translate_seg(
+ results, offset, self.direction, fill_val=self.seg_ignore_label)
+ self._filter_invalid(results, min_size=self.min_size)
+ return results
+
+
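``Translate`` converts the level into an integer pixel offset, ``int(level / _MAX_LEVEL * max_translate_offset)``, and then shifts the image, boxes, masks and segmentation maps together. A sketch with illustrative values:

    translate = dict(
        type='Translate',
        level=6,                 # offset = int(6 / 10 * 250) = 150 pixels
        prob=0.5,
        direction='horizontal')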
+@PIPELINES.register_module()
+class ColorTransform(object):
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_color_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get('img_fields', ['img']):
+            # NOTE: by default the image should be in BGR format
+ img = results[key]
+ results[key] = mmcv.adjust_color(img, factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_color_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class EqualizeTransform(object):
+ """Apply Equalize transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ prob (float): The probability for performing Equalize transformation.
+ """
+
+ def __init__(self, prob=0.5):
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.prob = prob
+
+ def _imequalize(self, results):
+ """Equalizes the histogram of one image."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.imequalize(img).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Equalize transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._imequalize(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+        repr_str += f'(prob={self.prob})'
+        return repr_str
+
+
+@PIPELINES.register_module()
+class BrightnessTransform(object):
+ """Apply Brightness transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Brightness transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_brightness_img(self, results, factor=1.0):
+ """Adjust the brightness of image."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_brightness(img,
+ factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Brightness transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_brightness_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ContrastTransform(object):
+ """Apply Contrast transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Contrast transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_contrast_img(self, results, factor=1.0):
+ """Adjust the image contrast."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_contrast(img, factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Contrast transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_contrast_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
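The ``AutoAugment`` docstring example above refers to 'Sharpness' and 'Color' types that are not registered in this file; a policies list built only from the transforms defined here might look like the following (levels and probabilities are illustrative):

    policies = [
        [dict(type='Rotate', level=6, prob=0.6),
         dict(type='BrightnessTransform', level=5, prob=0.4)],
        [dict(type='Shear', level=4, prob=0.5, direction='horizontal'),
         dict(type='ContrastTransform', level=3, prob=0.5)],
        [dict(type='Translate', level=5, prob=0.5, direction='vertical'),
         dict(type='EqualizeTransform', prob=0.4)],
    ]
    auto_augment = dict(type='AutoAugment', policies=policies)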
diff --git a/mmdet/datasets/pipelines/compose.py b/mmdet/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca48f1c935755c486edc2744e1713e2b5ba3cdc8
--- /dev/null
+++ b/mmdet/datasets/pipelines/compose.py
@@ -0,0 +1,51 @@
+import collections
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+ """Compose multiple transforms sequentially.
+
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+
+ Args:
+ data (dict): A result dict contains the data to transform.
+
+ Returns:
+ dict: Transformed data.
+ """
+
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += f' {t}'
+ format_string += '\n)'
+ return format_string
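``Compose`` accepts both config dicts (resolved through the ``PIPELINES`` registry via ``build_from_cfg``) and plain callables, and stops early if any transform returns ``None``. A minimal sketch:

    from mmdet.datasets.pipelines import Compose

    pipeline = Compose([
        dict(type='LoadImageFromFile'),       # built from the registry
        lambda results: results,              # a bare callable is kept as-is
    ])
    print(pipeline)                           # __repr__ lists each transform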
diff --git a/mmdet/datasets/pipelines/formating.py b/mmdet/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..5781341bd48766a740f23ebba7a85cf8993642d7
--- /dev/null
+++ b/mmdet/datasets/pipelines/formating.py
@@ -0,0 +1,364 @@
+from collections.abc import Sequence
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
+
+
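``to_tensor`` is the single conversion point used by the formatting transforms below; a quick sketch of how the branches above behave:

    to_tensor(np.zeros((2, 3)))   # ndarray -> torch.from_numpy (shares memory)
    to_tensor([1, 2, 3])          # sequence -> torch.tensor([1, 2, 3])
    to_tensor(7)                  # int -> torch.LongTensor([7])
    to_tensor(0.5)                # float -> torch.FloatTensor([0.5])
    to_tensor('img')              # raises TypeError: strings are rejected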
+@PIPELINES.register_module()
+class ToTensor(object):
+ """Convert some results to :obj:`torch.Tensor` by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert data in results to :obj:`torch.Tensor`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted
+ to :obj:`torch.Tensor`.
+ """
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+ """Convert image to :obj:`torch.Tensor` by given keys.
+
+ The dimension order of input image is (H, W, C). The pipeline will convert
+    it to (C, H, W). If only 2 dimensions (H, W) are given, the output will
+    be (1, H, W).
+
+ Args:
+ keys (Sequence[str]): Key of images to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img.transpose(2, 0, 1))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+ """Transpose some results by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys of results to be transposed.
+ order (Sequence[int]): Order of transpose.
+ """
+
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def __call__(self, results):
+ """Call function to transpose the channel order of data in results.
+
+ Args:
+ results (dict): Result dict contains the data to transpose.
+
+ Returns:
+ dict: The result dict contains the data transposed to \
+ ``self.order``.
+ """
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer(object):
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))``.
+ """
+
+ def __init__(self,
+ fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))):
+ self.fields = fields
+
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted to \
+ :obj:`mmcv.DataContainer`.
+ """
+
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle(object):
+ """Default formatting bundle.
+
+ It simplifies the pipeline of formatting common fields, including "img",
+ "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
+ These fields are formatted as follows.
+
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+ - proposals: (1)to tensor, (2)to DataContainer
+ - gt_bboxes: (1)to tensor, (2)to DataContainer
+ - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
+ - gt_labels: (1)to tensor, (2)to DataContainer
+ - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
+ (3)to DataContainer (stack=True)
+ """
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with \
+ default bundle.
+ """
+
+ if 'img' in results:
+ img = results['img']
+ # add default meta keys
+ results = self._add_default_meta_keys(results)
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ results['img'] = DC(to_tensor(img), stack=True)
+ for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels']:
+ if key not in results:
+ continue
+ results[key] = DC(to_tensor(results[key]))
+ if 'gt_masks' in results:
+ results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
+ if 'gt_semantic_seg' in results:
+ results['gt_semantic_seg'] = DC(
+ to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
+ return results
+
+ def _add_default_meta_keys(self, results):
+ """Add default meta keys.
+
+ We set default meta keys including `pad_shape`, `scale_factor` and
+ `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
+ `Pad` are implemented during the whole pipeline.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ results (dict): Updated result dict contains the data to convert.
+ """
+ img = results['img']
+ results.setdefault('pad_shape', img.shape)
+ results.setdefault('scale_factor', 1.0)
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results.setdefault(
+ 'img_norm_cfg',
+ dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
+
+
+@PIPELINES.register_module()
+class Collect(object):
+ """Collect data from the loader relevant to the specific task.
+
+ This is usually the last stage of the data loader pipeline. Typically keys
+ is set to some subset of "img", "proposals", "gt_bboxes",
+ "gt_bboxes_ignore", "gt_labels", and/or "gt_masks".
+
+ The "img_meta" item is always populated. The contents of the "img_meta"
+ dictionary depends on "meta_keys". By default this includes:
+
+ - "img_shape": shape of the image input to the network as a tuple \
+ (h, w, c). Note that images may be zero padded on the \
+ bottom/right if the batch tensor is larger than this shape.
+
+ - "scale_factor": a float indicating the preprocessing scale
+
+ - "flip": a boolean indicating if image flip transform was used
+
+ - "filename": path to the image file
+
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+ - "pad_shape": image shape after padding
+
+ - "img_norm_cfg": a dict of normalization information:
+
+ - mean - per channel mean subtraction
+ - std - per channel std divisor
+ - to_rgb - bool indicating if bgr was converted to rgb
+
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'img_norm_cfg')``
+ """
+
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
+ """Call function to collect keys in results. The keys in ``meta_keys``
+ will be converted to :obj:mmcv.DataContainer.
+
+ Args:
+ results (dict): Result dict contains the data to collect.
+
+ Returns:
+ dict: The result dict contains the following keys
+
+            - keys in ``self.keys``
+ - ``img_metas``
+ """
+
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
+
+
+@PIPELINES.register_module()
+class WrapFieldsToLists(object):
+ """Wrap fields of the data dictionary into lists for evaluation.
+
+ This class can be used as a last step of a test or validation
+ pipeline for single image evaluation or inference.
+
+ Example:
+ >>> test_pipeline = [
+ >>> dict(type='LoadImageFromFile'),
+ >>> dict(type='Normalize',
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ to_rgb=True),
+ >>> dict(type='Pad', size_divisor=32),
+ >>> dict(type='ImageToTensor', keys=['img']),
+ >>> dict(type='Collect', keys=['img']),
+ >>> dict(type='WrapFieldsToLists')
+ >>> ]
+ """
+
+ def __call__(self, results):
+ """Call function to wrap fields into lists.
+
+ Args:
+ results (dict): Result dict contains the data to wrap.
+
+ Returns:
+            dict: The result dict where each value is wrapped into a list.
+ """
+
+ # Wrap dict fields into lists
+ for key, val in results.items():
+ results[key] = [val]
+ return results
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
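Taken together, ``DefaultFormatBundle`` and ``Collect`` form the usual tail of a training pipeline; a sketch of the resulting batch layout (illustrative keys):

    format_tail = [
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ]
    # After Collect, the returned dict holds DataContainers:
    #   data['img']        -> DC(tensor, stack=True), shape (C, H, W)
    #   data['gt_bboxes']  -> DC(tensor), shape (N, 4)
    #   data['gt_labels']  -> DC(tensor), shape (N,)
    #   data['img_metas']  -> DC(dict, cpu_only=True) with filename, ori_shape, ...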
diff --git a/mmdet/datasets/pipelines/instaboost.py b/mmdet/datasets/pipelines/instaboost.py
new file mode 100644
index 0000000000000000000000000000000000000000..38b6819f60587a6e0c0f6d57bfda32bb3a7a4267
--- /dev/null
+++ b/mmdet/datasets/pipelines/instaboost.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class InstaBoost(object):
+ r"""Data augmentation method in `InstaBoost: Boosting Instance
+ Segmentation Via Probability Map Guided Copy-Pasting
+    <https://arxiv.org/abs/1908.07801>`_.
+
+ Refer to https://github.com/GothicAi/Instaboost for implementation details.
+ """
+
+ def __init__(self,
+ action_candidate=('normal', 'horizontal', 'skip'),
+ action_prob=(1, 0, 0),
+ scale=(0.8, 1.2),
+ dx=15,
+ dy=15,
+ theta=(-1, 1),
+ color_prob=0.5,
+ hflag=False,
+ aug_ratio=0.5):
+ try:
+ import instaboostfast as instaboost
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install instaboostfast" '
+ 'to install instaboostfast first for instaboost augmentation.')
+ self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
+ scale, dx, dy, theta,
+ color_prob, hflag)
+ self.aug_ratio = aug_ratio
+
+ def _load_anns(self, results):
+ labels = results['ann_info']['labels']
+ masks = results['ann_info']['masks']
+ bboxes = results['ann_info']['bboxes']
+ n = len(labels)
+
+ anns = []
+ for i in range(n):
+ label = labels[i]
+ bbox = bboxes[i]
+ mask = masks[i]
+ x1, y1, x2, y2 = bbox
+ # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
+ bbox = [x1, y1, x2 - x1, y2 - y1]
+ anns.append({
+ 'category_id': label,
+ 'segmentation': mask,
+ 'bbox': bbox
+ })
+
+ return anns
+
+ def _parse_anns(self, results, anns, img):
+ gt_bboxes = []
+ gt_labels = []
+ gt_masks_ann = []
+ for ann in anns:
+ x1, y1, w, h = ann['bbox']
+ # TODO: more essential bug need to be fixed in instaboost
+ if w <= 0 or h <= 0:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ gt_bboxes.append(bbox)
+ gt_labels.append(ann['category_id'])
+ gt_masks_ann.append(ann['segmentation'])
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ results['ann_info']['labels'] = gt_labels
+ results['ann_info']['bboxes'] = gt_bboxes
+ results['ann_info']['masks'] = gt_masks_ann
+ results['img'] = img
+ return results
+
+ def __call__(self, results):
+ img = results['img']
+ orig_type = img.dtype
+ anns = self._load_anns(results)
+ if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
+ try:
+ import instaboostfast as instaboost
+ except ImportError:
+ raise ImportError('Please run "pip install instaboostfast" '
+ 'to install instaboostfast first.')
+ anns, img = instaboost.get_new_data(
+ anns, img.astype(np.uint8), self.cfg, background=None)
+
+ results = self._parse_anns(results, anns, img.astype(orig_type))
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
+ return repr_str
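``InstaBoost`` reads and rewrites ``results['img']`` and ``results['ann_info']``, so in a pipeline it sits after image loading and before ``LoadAnnotations``; a sketch with illustrative values:

    pipeline = [
        dict(type='LoadImageFromFile'),
        dict(type='InstaBoost', action_prob=(1, 0, 0), aug_ratio=0.5),
        dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
    ]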
diff --git a/mmdet/datasets/pipelines/loading.py b/mmdet/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c1d11f364e29707069b881fdca6f99dc1a52680
--- /dev/null
+++ b/mmdet/datasets/pipelines/loading.py
@@ -0,0 +1,470 @@
+import os.path as osp
+
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+
+from mmdet.core import BitmapMasks, PolygonMasks
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+ """Load an image from file.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename"). Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ """Call functions to load image and get image meta information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = osp.join(results['img_prefix'],
+ results['img_info']['filename'])
+ else:
+ filename = results['img_info']['filename']
+
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
+
+
+@PIPELINES.register_module()
+class LoadImageFromWebcam(LoadImageFromFile):
+ """Load an image from webcam.
+
+    Similar to :obj:`LoadImageFromFile`, but the image read from webcam is in
+ ``results['img']``.
+ """
+
+ def __call__(self, results):
+ """Call functions to add image meta information.
+
+ Args:
+ results (dict): Result dict with Webcam read image in
+ ``results['img']``.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ img = results['img']
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = None
+ results['ori_filename'] = None
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+
+@PIPELINES.register_module()
+class LoadMultiChannelImageFromFiles(object):
+ """Load multi-channel images from a list of separate channel files.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename", which is expected to be a list of filenames).
+ Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+            numpy array. If set to False, the loaded image is a uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='unchanged',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ """Call functions to load multiple images and get images meta
+ information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded images and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = [
+ osp.join(results['img_prefix'], fname)
+ for fname in results['img_info']['filename']
+ ]
+ else:
+ filename = results['img_info']['filename']
+
+ img = []
+ for name in filename:
+ img_bytes = self.file_client.get(name)
+ img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
+ img = np.stack(img, axis=-1)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ # Set initial values for default meta_keys
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = 1.0
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False)
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
+
+
+@PIPELINES.register_module()
+class LoadAnnotations(object):
+    """Load multiple types of annotations.
+
+ Args:
+ with_bbox (bool): Whether to parse and load the bbox annotation.
+ Default: True.
+ with_label (bool): Whether to parse and load the label annotation.
+ Default: True.
+ with_mask (bool): Whether to parse and load the mask annotation.
+ Default: False.
+ with_seg (bool): Whether to parse and load the semantic segmentation
+ annotation. Default: False.
+ poly2mask (bool): Whether to convert the instance masks from polygons
+ to bitmaps. Default: True.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ with_bbox=True,
+ with_label=True,
+ with_mask=False,
+ with_seg=False,
+ poly2mask=True,
+ file_client_args=dict(backend='disk')):
+ self.with_bbox = with_bbox
+ self.with_label = with_label
+ self.with_mask = with_mask
+ self.with_seg = with_seg
+ self.poly2mask = poly2mask
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def _load_bboxes(self, results):
+ """Private function to load bounding box annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box annotations.
+ """
+
+ ann_info = results['ann_info']
+ results['gt_bboxes'] = ann_info['bboxes'].copy()
+
+ gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
+ if gt_bboxes_ignore is not None:
+ results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
+ results['bbox_fields'].append('gt_bboxes_ignore')
+ results['bbox_fields'].append('gt_bboxes')
+ return results
+
+ def _load_labels(self, results):
+ """Private function to load label annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded label annotations.
+ """
+
+ results['gt_labels'] = results['ann_info']['labels'].copy()
+ return results
+
+ def _poly2mask(self, mask_ann, img_h, img_w):
+ """Private function to convert masks represented with polygon to
+ bitmaps.
+
+ Args:
+ mask_ann (list | dict): Polygon mask annotation input.
+ img_h (int): The height of output mask.
+ img_w (int): The width of output mask.
+
+ Returns:
+            numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w).
+ """
+
+ if isinstance(mask_ann, list):
+ # polygon -- a single object might consist of multiple parts
+ # we merge all parts into one mask rle code
+ rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ rle = maskUtils.merge(rles)
+ elif isinstance(mask_ann['counts'], list):
+ # uncompressed RLE
+ rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ else:
+ # rle
+ rle = mask_ann
+ mask = maskUtils.decode(rle)
+ return mask
+
+ def process_polygons(self, polygons):
+ """Convert polygons to list of ndarray and filter invalid polygons.
+
+ Args:
+ polygons (list[list]): Polygons of one instance.
+
+ Returns:
+ list[numpy.ndarray]: Processed polygons.
+ """
+
+ polygons = [np.array(p) for p in polygons]
+ valid_polygons = []
+ for polygon in polygons:
+ if len(polygon) % 2 == 0 and len(polygon) >= 6:
+ valid_polygons.append(polygon)
+ return valid_polygons
+
+ def _load_masks(self, results):
+ """Private function to load mask annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded mask annotations.
+                If ``self.poly2mask`` is set ``True``, ``gt_masks`` will
+                contain :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks`
+                is used.
+ """
+
+ h, w = results['img_info']['height'], results['img_info']['width']
+ gt_masks = results['ann_info']['masks']
+ if self.poly2mask:
+            masks_all = []
+            for mask in gt_masks:
+                if 'full' in mask:
+                    # Amodal encoding: 2 marks pixels of the full mask that
+                    # are occluded, 1 marks pixels that are visible.
+                    full = self._poly2mask(mask['full'], h, w) * 2
+                    visible = self._poly2mask(mask['visible'], h, w)
+                    full[visible == 1] = 1
+                    masks_all.append(full)
+                else:
+                    visible = self._poly2mask(mask['visible'], h, w)
+                    masks_all.append(visible)
+
+ gt_masks = BitmapMasks(masks_all, h, w)
+ else:
+ gt_masks = PolygonMasks(
+ [self.process_polygons(polygons) for polygons in gt_masks], h,
+ w)
+ results['gt_masks'] = gt_masks
+ results['mask_fields'].append('gt_masks')
+ return results
+
+ def _load_semantic_seg(self, results):
+ """Private function to load semantic segmentation annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+ dict: The dict contains loaded semantic segmentation annotations.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ img_bytes = self.file_client.get(filename)
+ results['gt_semantic_seg'] = mmcv.imfrombytes(
+ img_bytes, flag='unchanged').squeeze()
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+
+ def __call__(self, results):
+        """Call function to load multiple types of annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box, label, mask and
+ semantic segmentation annotations.
+ """
+
+ if self.with_bbox:
+ results = self._load_bboxes(results)
+ if results is None:
+ return None
+ if self.with_label:
+ results = self._load_labels(results)
+ if self.with_mask:
+ results = self._load_masks(results)
+ if self.with_seg:
+ results = self._load_semantic_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(with_bbox={self.with_bbox}, '
+ repr_str += f'with_label={self.with_label}, '
+ repr_str += f'with_mask={self.with_mask}, '
+ repr_str += f'with_seg={self.with_seg}, '
+ repr_str += f'poly2mask={self.poly2mask}, '
+        repr_str += f'file_client_args={self.file_client_args})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class LoadProposals(object):
+ """Load proposal pipeline.
+
+ Required key is "proposals". Updated keys are "proposals", "bbox_fields".
+
+ Args:
+ num_max_proposals (int, optional): Maximum number of proposals to load.
+ If not specified, all proposals will be loaded.
+ """
+
+ def __init__(self, num_max_proposals=None):
+ self.num_max_proposals = num_max_proposals
+
+ def __call__(self, results):
+ """Call function to load proposals from file.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded proposal annotations.
+ """
+
+ proposals = results['proposals']
+ if proposals.shape[1] not in (4, 5):
+ raise AssertionError(
+                'proposals should have shape (n, 4) or (n, 5), '
+ f'but found {proposals.shape}')
+ proposals = proposals[:, :4]
+
+ if self.num_max_proposals is not None:
+ proposals = proposals[:self.num_max_proposals]
+
+ if len(proposals) == 0:
+ proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
+ results['proposals'] = proposals
+ results['bbox_fields'].append('proposals')
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(num_max_proposals={self.num_max_proposals})'
+
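+# Usage sketch (the 2000 cap is a hypothetical value): precomputed proposals
+# are typically consumed in a pipeline as
+#     dict(type='LoadProposals', num_max_proposals=2000),
+# which keeps at most the first 2000 proposals per image and pads an empty
+# proposal set with a single all-zero box.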
+
+@PIPELINES.register_module()
+class FilterAnnotations(object):
+ """Filter invalid annotations.
+
+ Args:
+ min_gt_bbox_wh (tuple[int]): Minimum width and height of ground truth
+ boxes.
+ """
+
+ def __init__(self, min_gt_bbox_wh):
+ # TODO: add more filter options
+ self.min_gt_bbox_wh = min_gt_bbox_wh
+
+ def __call__(self, results):
+ assert 'gt_bboxes' in results
+ gt_bboxes = results['gt_bboxes']
+ w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
+ h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
+ keep = (w > self.min_gt_bbox_wh[0]) & (h > self.min_gt_bbox_wh[1])
+ if not keep.any():
+ return None
+ else:
+ keys = ('gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg')
+ for key in keys:
+ if key in results:
+ results[key] = results[key][keep]
+ return results
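+
+# Usage sketch (thresholds are hypothetical): in a dataset pipeline config
+# these loading transforms are typically chained, e.g.
+#     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+#     dict(type='FilterAnnotations', min_gt_bbox_wh=(2, 2)),
+# FilterAnnotations returns None when no box survives the width/height
+# filter, so the caller can skip such a sample.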
diff --git a/mmdet/datasets/pipelines/test_time_aug.py b/mmdet/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6226e040499882c99f15594c66ebf3d07829168
--- /dev/null
+++ b/mmdet/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,119 @@
+import warnings
+
+import mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug(object):
+ """Test-time augmentation with multiple scales and flipping.
+
+ An example configuration is as follows:
+
+ .. code-block::
+
+ img_scale=[(1333, 400), (1333, 800)],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ]
+
+ After MultiScaleFlipAug with the above configuration, the results are wrapped
+ into lists of the same length, as follows:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...],
+ scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
+ flip=[False, True, False, True]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transforms to apply in each augmentation.
+ img_scale (tuple | list[tuple] | None): Image scales for resizing.
+ scale_factor (float | list[float] | None): Scale factors for resizing.
+ flip (bool): Whether apply flip augmentation. Default: False.
+ flip_direction (str | list[str]): Flip augmentation directions,
+ options are "horizontal" and "vertical". If flip_direction is list,
+ multiple flip augmentations will be applied.
+ It has no effect when flip == False. Default: "horizontal".
+ """
+
+ def __init__(self,
+ transforms,
+ img_scale=None,
+ scale_factor=None,
+ flip=False,
+ flip_direction='horizontal'):
+ self.transforms = Compose(transforms)
+ assert (img_scale is None) ^ (scale_factor is None), (
+ 'Exactly one of img_scale and scale_factor must be set')
+ if img_scale is not None:
+ self.img_scale = img_scale if isinstance(img_scale,
+ list) else [img_scale]
+ self.scale_key = 'scale'
+ assert mmcv.is_list_of(self.img_scale, tuple)
+ else:
+ self.img_scale = scale_factor if isinstance(
+ scale_factor, list) else [scale_factor]
+ self.scale_key = 'scale_factor'
+
+ self.flip = flip
+ self.flip_direction = flip_direction if isinstance(
+ flip_direction, list) else [flip_direction]
+ assert mmcv.is_list_of(self.flip_direction, str)
+ if not self.flip and self.flip_direction != ['horizontal']:
+ warnings.warn(
+ 'flip_direction has no effect when flip is set to False')
+ if (self.flip
+ and not any([t['type'] == 'RandomFlip' for t in transforms])):
+ warnings.warn(
+ 'flip has no effect when RandomFlip is not in transforms')
+
+ def __call__(self, results):
+ """Call function to apply test time augment transforms on results.
+
+ Args:
+ results (dict): Result dict contains the data to transform.
+
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+
+ aug_data = []
+ flip_args = [(False, None)]
+ if self.flip:
+ flip_args += [(True, direction)
+ for direction in self.flip_direction]
+ for scale in self.img_scale:
+ for flip, direction in flip_args:
+ _results = results.copy()
+ _results[self.scale_key] = scale
+ _results['flip'] = flip
+ _results['flip_direction'] = direction
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+ repr_str += f'flip_direction={self.flip_direction})'
+ return repr_str
diff --git a/mmdet/datasets/pipelines/transforms.py b/mmdet/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..5166fc09bd16ab7f4a5b59485fe7976bfd2dfdd2
--- /dev/null
+++ b/mmdet/datasets/pipelines/transforms.py
@@ -0,0 +1,1812 @@
+import copy
+import inspect
+
+import mmcv
+import numpy as np
+from numpy import random
+
+from mmdet.core import PolygonMasks
+from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
+from ..builder import PIPELINES
+
+try:
+ from imagecorruptions import corrupt
+except ImportError:
+ corrupt = None
+
+try:
+ import albumentations
+ from albumentations import Compose
+except ImportError:
+ albumentations = None
+ Compose = None
+
+
+@PIPELINES.register_module()
+class Resize(object):
+ """Resize images & bbox & mask.
+
+ This transform resizes the input image to some scale. Bboxes and masks are
+ then resized with the same scale factor. If the input dict contains the key
+ "scale", then the scale in the input dict is used, otherwise the specified
+ scale in the init method is used. If the input dict contains the key
+ "scale_factor" (if MultiScaleFlipAug does not give img_scale but
+ scale_factor), the actual scale will be computed by image shape and
+ scale_factor.
+
+ `img_scale` can either be a tuple (single-scale) or a list of tuple
+ (multi-scale). There are 3 multiscale modes:
+
+ - ``ratio_range is not None``: randomly sample a ratio from the ratio \
+ range and multiply it with the image scale.
+ - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
+ sample a scale from the multiscale range.
+ - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
+ sample a scale from multiple scales.
+
+ Args:
+ img_scale (tuple or list[tuple]): Image scales for resizing.
+ multiscale_mode (str): Either "range" or "value".
+ ratio_range (tuple[float]): (min_ratio, max_ratio)
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+ backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+ These two backends generate slightly different results. Defaults
+ to 'cv2'.
+ override (bool, optional): Whether to override `scale` and
+ `scale_factor` so as to call resize twice. Default False. If True,
+ after the first resizing, the existing `scale` and `scale_factor`
+ will be ignored so the second resizing can be allowed.
+ This option is a work-around for multiple times of resize in DETR.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=True,
+ bbox_clip_border=True,
+ backend='cv2',
+ override=False):
+ if img_scale is None:
+ self.img_scale = None
+ else:
+ if isinstance(img_scale, list):
+ self.img_scale = img_scale
+ else:
+ self.img_scale = [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple)
+
+ if ratio_range is not None:
+ # mode 1: given a scale and a range of image ratio
+ assert len(self.img_scale) == 1
+ else:
+ # mode 2: given multiple scales or a range of scales
+ assert multiscale_mode in ['value', 'range']
+
+ self.backend = backend
+ self.multiscale_mode = multiscale_mode
+ self.ratio_range = ratio_range
+ self.keep_ratio = keep_ratio
+ # TODO: refactor the override option in Resize
+ self.override = override
+ self.bbox_clip_border = bbox_clip_border
+
+ @staticmethod
+ def random_select(img_scales):
+ """Randomly select an img_scale from given candidates.
+
+ Args:
+ img_scales (list[tuple]): Image scales for selection.
+
+ Returns:
+ (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
+ where ``img_scale`` is the selected image scale and \
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple)
+ scale_idx = np.random.randint(len(img_scales))
+ img_scale = img_scales[scale_idx]
+ return img_scale, scale_idx
+
+ @staticmethod
+ def random_sample(img_scales):
+ """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+
+ Args:
+ img_scales (list[tuple]): Image scale range for sampling.
+ There must be two tuples in img_scales, which specify the lower
+ and upper bound of image scales.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(img_scale, None)``, where \
+ ``img_scale`` is sampled scale and None is just a placeholder \
+ to be consistent with :func:`random_select`.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+ img_scale_long = [max(s) for s in img_scales]
+ img_scale_short = [min(s) for s in img_scales]
+ long_edge = np.random.randint(
+ min(img_scale_long),
+ max(img_scale_long) + 1)
+ short_edge = np.random.randint(
+ min(img_scale_short),
+ max(img_scale_short) + 1)
+ img_scale = (long_edge, short_edge)
+ return img_scale, None
+
+ @staticmethod
+ def random_sample_ratio(img_scale, ratio_range):
+ """Randomly sample an img_scale when ``ratio_range`` is specified.
+
+ A ratio will be randomly sampled from the range specified by
+ ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+ generate sampled scale.
+
+ Args:
+ img_scale (tuple): Image scale base to multiply with ratio.
+ ratio_range (tuple[float]): The minimum and maximum ratio to scale
+ the ``img_scale``.
+
+ Returns:
+ (tuple, None): Returns a tuple ``(scale, None)``, where \
+ ``scale`` is sampled ratio multiplied with ``img_scale`` and \
+ None is just a placeholder to be consistent with \
+ :func:`random_select`.
+ """
+
+ assert isinstance(img_scale, tuple) and len(img_scale) == 2
+ min_ratio, max_ratio = ratio_range
+ assert min_ratio <= max_ratio
+ ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+ scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+ return scale, None
+
+ def _random_scale(self, results):
+ """Randomly sample an img_scale according to ``ratio_range`` and
+ ``multiscale_mode``.
+
+ If ``ratio_range`` is specified, a ratio will be sampled and be
+ multiplied with ``img_scale``.
+ If multiple scales are specified by ``img_scale``, a scale will be
+ sampled according to ``multiscale_mode``.
+ Otherwise, single scale will be used.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+ dict: Two new keys ``scale`` and ``scale_idx`` are added into \
+ ``results``, which would be used by subsequent pipelines.
+ """
+
+ if self.ratio_range is not None:
+ scale, scale_idx = self.random_sample_ratio(
+ self.img_scale[0], self.ratio_range)
+ elif len(self.img_scale) == 1:
+ scale, scale_idx = self.img_scale[0], 0
+ elif self.multiscale_mode == 'range':
+ scale, scale_idx = self.random_sample(self.img_scale)
+ elif self.multiscale_mode == 'value':
+ scale, scale_idx = self.random_select(self.img_scale)
+ else:
+ raise NotImplementedError
+
+ results['scale'] = scale
+ results['scale_idx'] = scale_idx
+
+ def _resize_img(self, results):
+ """Resize images with ``results['scale']``."""
+ for key in results.get('img_fields', ['img']):
+ if self.keep_ratio:
+ img, scale_factor = mmcv.imrescale(
+ results[key],
+ results['scale'],
+ return_scale=True,
+ backend=self.backend)
+ # the w_scale and h_scale have a minor difference
+ # a real fix should be done in the mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results[key].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ results[key],
+ results['scale'],
+ return_scale=True,
+ backend=self.backend)
+ results[key] = img
+
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+ dtype=np.float32)
+ results['img_shape'] = img.shape
+ # in case that there is no padding
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = scale_factor
+ results['keep_ratio'] = self.keep_ratio
+
+ def _resize_bboxes(self, results):
+ """Resize bounding boxes with ``results['scale_factor']``."""
+ for key in results.get('bbox_fields', []):
+ bboxes = results[key] * results['scale_factor']
+ if self.bbox_clip_border:
+ img_shape = results['img_shape']
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+ results[key] = bboxes
+
+ def _resize_masks(self, results):
+ """Resize masks with ``results['scale']``"""
+ for key in results.get('mask_fields', []):
+ if results[key] is None:
+ continue
+ if self.keep_ratio:
+ results[key] = results[key].rescale(results['scale'])
+ else:
+ results[key] = results[key].resize(results['img_shape'][:2])
+
+ def _resize_seg(self, results):
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for key in results.get('seg_fields', []):
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ else:
+ gt_seg = mmcv.imresize(
+ results[key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ results['gt_semantic_seg'] = gt_seg
+
+ def __call__(self, results):
+ """Call function to resize images, bounding boxes, masks, semantic
+ segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \
+ 'keep_ratio' keys are added into result dict.
+ """
+
+ if 'scale' not in results:
+ if 'scale_factor' in results:
+ img_shape = results['img'].shape[:2]
+ scale_factor = results['scale_factor']
+ assert isinstance(scale_factor, float)
+ results['scale'] = tuple(
+ [int(x * scale_factor) for x in img_shape][::-1])
+ else:
+ self._random_scale(results)
+ else:
+ if not self.override:
+ assert 'scale_factor' not in results, (
+ 'scale and scale_factor cannot be both set.')
+ else:
+ results.pop('scale')
+ if 'scale_factor' in results:
+ results.pop('scale_factor')
+ self._random_scale(results)
+
+ self._resize_img(results)
+ self._resize_bboxes(results)
+ self._resize_masks(results)
+ self._resize_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(img_scale={self.img_scale}, '
+ repr_str += f'multiscale_mode={self.multiscale_mode}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'keep_ratio={self.keep_ratio}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
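+# Configuration sketch (the scales are hypothetical) covering the three
+# multiscale modes described in the Resize docstring:
+#     dict(type='Resize', img_scale=(1333, 800), ratio_range=(0.8, 1.2))
+#     dict(type='Resize', img_scale=[(1333, 640), (1333, 800)],
+#          multiscale_mode='range')
+#     dict(type='Resize', img_scale=[(1333, 480), (1333, 960)],
+#          multiscale_mode='value')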
+
+@PIPELINES.register_module()
+class RandomFlip(object):
+ """Flip the image & bbox & mask.
+
+ If the input dict contains the key "flip", then the flag will be used,
+ otherwise it will be randomly decided by a ratio specified in the init
+ method.
+
+ When random flip is enabled, ``flip_ratio``/``direction`` can either be a
+ float/string or tuple of float/string. There are 3 flip modes:
+
+ - ``flip_ratio`` is float, ``direction`` is string: the image will be
+ ``direction``ly flipped with probability of ``flip_ratio``.
+ E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
+ then image will be horizontally flipped with probability of 0.5.
+ - ``flip_ratio`` is float, ``direction`` is list of string: the image will
+ be ``direction[i]``ly flipped with probability of
+ ``flip_ratio/len(direction)``.
+ E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
+ then image will be horizontally flipped with probability of 0.25,
+ vertically with probability of 0.25.
+ - ``flip_ratio`` is list of float, ``direction`` is list of string:
+ given ``len(flip_ratio) == len(direction)``, the image will
+ be ``direction[i]``ly flipped with probability of ``flip_ratio[i]``.
+ E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
+ 'vertical']``, then image will be horizontally flipped with probability
+ of 0.3, vertically with probability of 0.5.
+
+ Args:
+ flip_ratio (float | list[float], optional): The flipping probability.
+ Default: None.
+ direction (str | list[str], optional): The flipping direction. Options
+ are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'.
+ If input is a list, the length must equal ``flip_ratio``. Each
+ element in ``flip_ratio`` indicates the flip probability of
+ corresponding direction.
+ """
+
+ def __init__(self, flip_ratio=None, direction='horizontal'):
+ if isinstance(flip_ratio, list):
+ assert mmcv.is_list_of(flip_ratio, float)
+ assert 0 <= sum(flip_ratio) <= 1
+ elif isinstance(flip_ratio, float):
+ assert 0 <= flip_ratio <= 1
+ elif flip_ratio is None:
+ pass
+ else:
+ raise ValueError('flip_ratios must be None, float, '
+ 'or list of float')
+ self.flip_ratio = flip_ratio
+
+ valid_directions = ['horizontal', 'vertical', 'diagonal']
+ if isinstance(direction, str):
+ assert direction in valid_directions
+ elif isinstance(direction, list):
+ assert mmcv.is_list_of(direction, str)
+ assert set(direction).issubset(set(valid_directions))
+ else:
+ raise ValueError('direction must be either str or list of str')
+ self.direction = direction
+
+ if isinstance(flip_ratio, list):
+ assert len(self.flip_ratio) == len(self.direction)
+
+ def bbox_flip(self, bboxes, img_shape, direction):
+ """Flip bboxes horizontally.
+
+ Args:
+ bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
+ img_shape (tuple[int]): Image shape (height, width)
+ direction (str): Flip direction. Options are 'horizontal',
+ 'vertical'.
+
+ Returns:
+ numpy.ndarray: Flipped bounding boxes.
+ """
+
+ assert bboxes.shape[-1] % 4 == 0
+ flipped = bboxes.copy()
+ if direction == 'horizontal':
+ w = img_shape[1]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ h = img_shape[0]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ elif direction == 'diagonal':
+ w = img_shape[1]
+ h = img_shape[0]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ else:
+ raise ValueError(f"Invalid flipping direction '{direction}'")
+ return flipped
+
+ def __call__(self, results):
+ """Call function to flip bounding boxes, masks, semantic segmentation
+ maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Flipped results, 'flip', 'flip_direction' keys are added \
+ into result dict.
+ """
+
+ if 'flip' not in results:
+ if isinstance(self.direction, list):
+ # None means non-flip
+ direction_list = self.direction + [None]
+ else:
+ # None means non-flip
+ direction_list = [self.direction, None]
+
+ if isinstance(self.flip_ratio, list):
+ non_flip_ratio = 1 - sum(self.flip_ratio)
+ flip_ratio_list = self.flip_ratio + [non_flip_ratio]
+ else:
+ non_flip_ratio = 1 - self.flip_ratio
+ # exclude non-flip
+ single_ratio = self.flip_ratio / (len(direction_list) - 1)
+ flip_ratio_list = [single_ratio] * (len(direction_list) -
+ 1) + [non_flip_ratio]
+
+ cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
+
+ results['flip'] = cur_dir is not None
+ if 'flip_direction' not in results:
+ results['flip_direction'] = cur_dir
+ if results['flip']:
+ # flip image
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction'])
+ # flip bboxes
+ for key in results.get('bbox_fields', []):
+ results[key] = self.bbox_flip(results[key],
+ results['img_shape'],
+ results['flip_direction'])
+ # flip masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].flip(results['flip_direction'])
+
+ # flip segs
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction'])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
+
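+# Configuration sketch (probabilities are hypothetical) for the three flip
+# modes listed in the RandomFlip docstring:
+#     dict(type='RandomFlip', flip_ratio=0.5)  # horizontal, p=0.5
+#     dict(type='RandomFlip', flip_ratio=0.5,
+#          direction=['horizontal', 'vertical'])  # p=0.25 each
+#     dict(type='RandomFlip', flip_ratio=[0.3, 0.5],
+#          direction=['horizontal', 'vertical'])  # p=0.3 and p=0.5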
+
+@PIPELINES.register_module()
+class Pad(object):
+ """Pad the image & mask.
+
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_val (float, optional): Padding value, 0 by default.
+ """
+
+ def __init__(self, size=None, size_divisor=None, pad_val=0):
+ self.size = size
+ self.size_divisor = size_divisor
+ self.pad_val = pad_val
+ # only one of size and size_divisor should be valid
+ assert size is not None or size_divisor is not None
+ assert size is None or size_divisor is None
+
+ def _pad_img(self, results):
+ """Pad images according to ``self.size``."""
+ for key in results.get('img_fields', ['img']):
+ if self.size is not None:
+ padded_img = mmcv.impad(
+ results[key], shape=self.size, pad_val=self.pad_val)
+ elif self.size_divisor is not None:
+ padded_img = mmcv.impad_to_multiple(
+ results[key], self.size_divisor, pad_val=self.pad_val)
+ results[key] = padded_img
+ results['pad_shape'] = padded_img.shape
+ results['pad_fixed_size'] = self.size
+ results['pad_size_divisor'] = self.size_divisor
+
+ def _pad_masks(self, results):
+ """Pad masks according to ``results['pad_shape']``."""
+ pad_shape = results['pad_shape'][:2]
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].pad(pad_shape, pad_val=self.pad_val)
+
+ def _pad_seg(self, results):
+ """Pad semantic segmentation map according to
+ ``results['pad_shape']``."""
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.impad(
+ results[key], shape=results['pad_shape'][:2])
+
+ def __call__(self, results):
+ """Call function to pad images, masks, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+ self._pad_img(results)
+ self._pad_masks(results)
+ self._pad_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, '
+ repr_str += f'size_divisor={self.size_divisor}, '
+ repr_str += f'pad_val={self.pad_val})'
+ return repr_str
+
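+# Configuration sketch (sizes are hypothetical): exactly one of the two
+# padding modes can be used at a time:
+#     dict(type='Pad', size=(800, 1344))   # pad to a fixed (h, w)
+#     dict(type='Pad', size_divisor=32)    # pad so h and w become multiples of 32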
+
+@PIPELINES.register_module()
+class Normalize(object):
+ """Normalize the image.
+
+ Added key is "img_norm_cfg".
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ """Call function to normalize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Normalized results, 'img_norm_cfg' key is added into
+ result dict.
+ """
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(
+ mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
+ return repr_str
+
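+# Configuration sketch: a commonly used img_norm_cfg (these values are an
+# assumption here, not taken from this repository's configs):
+#     img_norm_cfg = dict(mean=[123.675, 116.28, 103.53],
+#                         std=[58.395, 57.12, 57.375], to_rgb=True)
+#     dict(type='Normalize', **img_norm_cfg)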
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+ """Random crop the image & bboxes & masks.
+
+ The absolute `crop_size` is sampled based on `crop_type` and `image_size`,
+ then the cropped results are generated.
+
+ Args:
+ crop_size (tuple): The relative ratio or absolute pixels of
+ height and width.
+ crop_type (str, optional): one of "relative_range", "relative",
+ "absolute", "absolute_range". "relative" randomly crops
+ (h * crop_size[0], w * crop_size[1]) part from an input of size
+ (h, w). "relative_range" uniformly samples relative crop size from
+ range [crop_size[0], 1] and [crop_size[1], 1] for height and width
+ respectively. "absolute" crops from an input with absolute size
+ (crop_size[0], crop_size[1]). "absolute_range" uniformly samples
+ crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
+ in range [crop_size[0], min(w, crop_size[1])]. Default "absolute".
+ allow_negative_crop (bool, optional): Whether to allow a crop that does
+ not contain any bbox area. Default False.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+
+ Note:
+ - If the image is smaller than the absolute crop size, return the
+ original image.
+ - The keys for bboxes, labels and masks must be aligned. That is,
+ `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
+ `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and
+ `gt_masks_ignore`.
+ - If the crop does not contain any gt-bbox region and
+ `allow_negative_crop` is set to False, skip this image.
+ """
+
+ def __init__(self,
+ crop_size,
+ crop_type='absolute',
+ allow_negative_crop=False,
+ bbox_clip_border=True):
+ if crop_type not in [
+ 'relative_range', 'relative', 'absolute', 'absolute_range'
+ ]:
+ raise ValueError(f'Invalid crop_type {crop_type}.')
+ if crop_type in ['absolute', 'absolute_range']:
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ assert isinstance(crop_size[0], int) and isinstance(
+ crop_size[1], int)
+ else:
+ assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
+ self.crop_size = crop_size
+ self.crop_type = crop_type
+ self.allow_negative_crop = allow_negative_crop
+ self.bbox_clip_border = bbox_clip_border
+ # The key correspondence from bboxes to labels and masks.
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ self.bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+
+ def _crop_data(self, results, crop_size, allow_negative_crop):
+ """Function to randomly crop images, bounding boxes, masks, semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ crop_size (tuple): Expected absolute size after cropping, (h, w).
+ allow_negative_crop (bool): Whether to allow a crop that does not
+ contain any bbox area. Default to False.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ margin_h = max(img.shape[0] - crop_size[0], 0)
+ margin_w = max(img.shape[1] - crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
+
+ # crop the image
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ img_shape = img.shape
+ results[key] = img
+ results['img_shape'] = img_shape
+
+ # crop bboxes accordingly and clip to the image boundary
+ for key in results.get('bbox_fields', []):
+ # e.g. gt_bboxes and gt_bboxes_ignore
+ bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
+ dtype=np.float32)
+ bboxes = results[key] - bbox_offset
+ if self.bbox_clip_border:
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+ valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
+ bboxes[:, 3] > bboxes[:, 1])
+ # If the crop does not contain any gt-bbox area and
+ # allow_negative_crop is False, skip this image.
+ if (key == 'gt_bboxes' and not valid_inds.any()
+ and not allow_negative_crop):
+ return None
+ results[key] = bboxes[valid_inds, :]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = self.bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][
+ valid_inds.nonzero()[0]].crop(
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]
+
+ return results
+
+ def _get_crop_size(self, image_size):
+ """Randomly generates the absolute crop size based on `crop_type` and
+ `image_size`.
+
+ Args:
+ image_size (tuple): (h, w).
+
+ Returns:
+ crop_size (tuple): (crop_h, crop_w) in absolute pixels.
+ """
+ h, w = image_size
+ if self.crop_type == 'absolute':
+ return (min(self.crop_size[0], h), min(self.crop_size[1], w))
+ elif self.crop_type == 'absolute_range':
+ assert self.crop_size[0] <= self.crop_size[1]
+ crop_h = np.random.randint(
+ min(h, self.crop_size[0]),
+ min(h, self.crop_size[1]) + 1)
+ crop_w = np.random.randint(
+ min(w, self.crop_size[0]),
+ min(w, self.crop_size[1]) + 1)
+ return crop_h, crop_w
+ elif self.crop_type == 'relative':
+ crop_h, crop_w = self.crop_size
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+ elif self.crop_type == 'relative_range':
+ crop_size = np.asarray(self.crop_size, dtype=np.float32)
+ crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+
+ def __call__(self, results):
+ """Call function to randomly crop images, bounding boxes, masks,
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ image_size = results['img'].shape[:2]
+ crop_size = self._get_crop_size(image_size)
+ results = self._crop_data(results, crop_size, self.allow_negative_crop)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(crop_size={self.crop_size}, '
+ repr_str += f'crop_type={self.crop_type}, '
+ repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
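+# Configuration sketch (sizes are hypothetical), one entry per crop_type:
+#     dict(type='RandomCrop', crop_size=(512, 512), crop_type='absolute')
+#     dict(type='RandomCrop', crop_size=(384, 600), crop_type='absolute_range')
+#     dict(type='RandomCrop', crop_size=(0.6, 0.6), crop_type='relative')
+#     dict(type='RandomCrop', crop_size=(0.3, 0.3), crop_type='relative_range')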
+
+@PIPELINES.register_module()
+class SegRescale(object):
+ """Rescale semantic segmentation maps.
+
+ Args:
+ scale_factor (float): The scale factor of the final output.
+ backend (str): Image rescale backend, choices are 'cv2' and 'pillow'.
+ These two backends generate slightly different results. Defaults
+ to 'cv2'.
+ """
+
+ def __init__(self, scale_factor=1, backend='cv2'):
+ self.scale_factor = scale_factor
+ self.backend = backend
+
+ def __call__(self, results):
+ """Call function to scale the semantic segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with semantic segmentation map scaled.
+ """
+
+ for key in results.get('seg_fields', []):
+ if self.scale_factor != 1:
+ results[key] = mmcv.imrescale(
+ results[key],
+ self.scale_factor,
+ interpolation='nearest',
+ backend=self.backend)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion(object):
+ """Apply photometric distortion to image sequentially, every transformation
+ is applied with a probability of 0.5. The position of random contrast is in
+ second or second to last.
+
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+ 8. randomly swap channels
+
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
+ """
+
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+
+ def __call__(self, results):
+ """Call function to perform photometric distortion on images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images distorted.
+ """
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+ assert img.dtype == np.float32, \
+ 'PhotoMetricDistortion needs the input image of dtype np.float32,'\
+ ' please set "to_float32=True" in "LoadImageFromFile" pipeline'
+ # random brightness
+ if random.randint(2):
+ delta = random.uniform(-self.brightness_delta,
+ self.brightness_delta)
+ img += delta
+
+ # mode == 0 --> do random contrast first
+ # mode == 1 --> do random contrast last
+ mode = random.randint(2)
+ if mode == 1:
+ if random.randint(2):
+ alpha = random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ img *= alpha
+
+ # convert color from BGR to HSV
+ img = mmcv.bgr2hsv(img)
+
+ # random saturation
+ if random.randint(2):
+ img[..., 1] *= random.uniform(self.saturation_lower,
+ self.saturation_upper)
+
+ # random hue
+ if random.randint(2):
+ img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
+ img[..., 0][img[..., 0] > 360] -= 360
+ img[..., 0][img[..., 0] < 0] += 360
+
+ # convert color from HSV to BGR
+ img = mmcv.hsv2bgr(img)
+
+ # random contrast
+ if mode == 0:
+ if random.randint(2):
+ alpha = random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ img *= alpha
+
+ # randomly swap channels
+ if random.randint(2):
+ img = img[..., random.permutation(3)]
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
+ repr_str += 'contrast_range='
+ repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
+ repr_str += 'saturation_range='
+ repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
+ repr_str += f'hue_delta={self.hue_delta})'
+ return repr_str
+
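+# Usage sketch: since __call__ above asserts a float32 input image, the image
+# loader in the same pipeline would enable to_float32, e.g.
+#     dict(type='LoadImageFromFile', to_float32=True),
+#     dict(type='PhotoMetricDistortion'),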
+
+@PIPELINES.register_module()
+class Expand(object):
+ """Random expand the image & bboxes.
+
+ Randomly place the original image on a canvas of 'ratio' x original image
+ size filled with mean values. The ratio is in the range of ratio_range.
+
+ Args:
+ mean (tuple): mean value of dataset.
+ to_rgb (bool): if need to convert the order of mean to align with RGB.
+ ratio_range (tuple): range of expand ratio.
+ prob (float): probability of applying this transformation
+ """
+
+ def __init__(self,
+ mean=(0, 0, 0),
+ to_rgb=True,
+ ratio_range=(1, 4),
+ seg_ignore_label=None,
+ prob=0.5):
+ self.to_rgb = to_rgb
+ self.ratio_range = ratio_range
+ if to_rgb:
+ self.mean = mean[::-1]
+ else:
+ self.mean = mean
+ self.min_ratio, self.max_ratio = ratio_range
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+
+ def __call__(self, results):
+ """Call function to expand images, bounding boxes.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images, bounding boxes expanded
+ """
+
+ if random.uniform(0, 1) > self.prob:
+ return results
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+
+ h, w, c = img.shape
+ ratio = random.uniform(self.min_ratio, self.max_ratio)
+ # speed up expand for large images
+ if np.all(self.mean == self.mean[0]):
+ expand_img = np.empty((int(h * ratio), int(w * ratio), c),
+ img.dtype)
+ expand_img.fill(self.mean[0])
+ else:
+ expand_img = np.full((int(h * ratio), int(w * ratio), c),
+ self.mean,
+ dtype=img.dtype)
+ left = int(random.uniform(0, w * ratio - w))
+ top = int(random.uniform(0, h * ratio - h))
+ expand_img[top:top + h, left:left + w] = img
+
+ results['img'] = expand_img
+ # expand bboxes
+ for key in results.get('bbox_fields', []):
+ results[key] = results[key] + np.tile(
+ (left, top), 2).astype(results[key].dtype)
+
+ # expand masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].expand(
+ int(h * ratio), int(w * ratio), top, left)
+
+ # expand segs
+ for key in results.get('seg_fields', []):
+ gt_seg = results[key]
+ expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
+ self.seg_ignore_label,
+ dtype=gt_seg.dtype)
+ expand_gt_seg[top:top + h, left:left + w] = gt_seg
+ results[key] = expand_gt_seg
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label})'
+ return repr_str
+
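+# Worked example (numbers are hypothetical): with ratio_range=(1, 4) and a
+# 100x200 input, a sampled ratio of 2.0 gives a 200x400 canvas filled with
+# ``mean``; the image is pasted at a random (top, left) offset and all
+# bboxes are shifted by that same offset.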
+
+@PIPELINES.register_module()
+class MinIoURandomCrop(object):
+ """Random crop the image & bboxes, the cropped patches have minimum IoU
+ requirement with original image & bboxes, the IoU threshold is randomly
+ selected from min_ious.
+
+ Args:
+ min_ious (tuple): minimum IoU threshold for all intersections with
+ bounding boxes
+ min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
+ where a >= min_crop_size).
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+
+ Note:
+ The keys for bboxes, labels and masks should be paired. That is, \
+ `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
+ `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
+ """
+
+ def __init__(self,
+ min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+ min_crop_size=0.3,
+ bbox_clip_border=True):
+ # 1: return ori img
+ self.min_ious = min_ious
+ self.sample_mode = (1, *min_ious, 0)
+ self.min_crop_size = min_crop_size
+ self.bbox_clip_border = bbox_clip_border
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ self.bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+
+ def __call__(self, results):
+ """Call function to crop images and bounding boxes with minimum IoU
+ constraint.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images and bounding boxes cropped, \
+ 'img_shape' key is updated.
+ """
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+ assert 'bbox_fields' in results
+ boxes = [results[key] for key in results['bbox_fields']]
+ boxes = np.concatenate(boxes, 0)
+ h, w, c = img.shape
+ while True:
+ mode = random.choice(self.sample_mode)
+ self.mode = mode
+ if mode == 1:
+ return results
+
+ min_iou = mode
+ for i in range(50):
+ new_w = random.uniform(self.min_crop_size * w, w)
+ new_h = random.uniform(self.min_crop_size * h, h)
+
+ # h / w in [0.5, 2]
+ if new_h / new_w < 0.5 or new_h / new_w > 2:
+ continue
+
+ left = random.uniform(w - new_w)
+ top = random.uniform(h - new_h)
+
+ patch = np.array(
+ (int(left), int(top), int(left + new_w), int(top + new_h)))
+ # Line or point crop is not allowed
+ if patch[2] == patch[0] or patch[3] == patch[1]:
+ continue
+ overlaps = bbox_overlaps(
+ patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
+ if len(overlaps) > 0 and overlaps.min() < min_iou:
+ continue
+
+ # center of boxes should be inside the cropped img
+ # only adjust boxes and instance masks when the gt is not empty
+ if len(overlaps) > 0:
+ # adjust boxes
+ def is_center_of_bboxes_in_patch(boxes, patch):
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask = ((center[:, 0] > patch[0]) *
+ (center[:, 1] > patch[1]) *
+ (center[:, 0] < patch[2]) *
+ (center[:, 1] < patch[3]))
+ return mask
+
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
+ if not mask.any():
+ continue
+ for key in results.get('bbox_fields', []):
+ boxes = results[key].copy()
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
+ boxes = boxes[mask]
+ if self.bbox_clip_border:
+ boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
+ boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
+ boxes -= np.tile(patch[:2], 2)
+
+ results[key] = boxes
+ # labels
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][mask]
+
+ # mask fields
+ mask_key = self.bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][
+ mask.nonzero()[0]].crop(patch)
+ # adjust the img no matter whether the gt is empty before crop
+ img = img[patch[1]:patch[3], patch[0]:patch[2]]
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ # seg fields
+ for key in results.get('seg_fields', []):
+ results[key] = results[key][patch[1]:patch[3],
+ patch[0]:patch[2]]
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_ious={self.min_ious}, '
+ repr_str += f'min_crop_size={self.min_crop_size}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
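+# Configuration sketch (shown with the default-like values above):
+#     dict(type='MinIoURandomCrop',
+#          min_ious=(0.1, 0.3, 0.5, 0.7, 0.9), min_crop_size=0.3)
+# sample_mode becomes (1, 0.1, 0.3, 0.5, 0.7, 0.9, 0): drawing 1 returns the
+# image unchanged, while drawing 0 places no IoU constraint on the crop.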
+
+@PIPELINES.register_module()
+class Corrupt(object):
+ """Corruption augmentation.
+
+ Corruption transforms implemented based on
+ `imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.
+
+ Args:
+ corruption (str): Corruption name.
+ severity (int, optional): The severity of corruption. Default: 1.
+ """
+
+ def __init__(self, corruption, severity=1):
+ self.corruption = corruption
+ self.severity = severity
+
+ def __call__(self, results):
+ """Call function to corrupt image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images corrupted.
+ """
+
+ if corrupt is None:
+ raise RuntimeError('imagecorruptions is not installed')
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ results['img'] = corrupt(
+ results['img'].astype(np.uint8),
+ corruption_name=self.corruption,
+ severity=self.severity)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(corruption={self.corruption}, '
+ repr_str += f'severity={self.severity})'
+ return repr_str
+
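+# Configuration sketch (the corruption name is an example of an
+# imagecorruptions preset):
+#     dict(type='Corrupt', corruption='gaussian_noise', severity=2)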
+
+@PIPELINES.register_module()
+class Albu(object):
+ """Albumentation augmentation.
+
+ Adds custom transformations from the Albumentations library.
+ Please visit `https://albumentations.readthedocs.io`
+ for more information.
+
+ An example of ``transforms`` is as follows:
+
+ .. code-block::
+
+ [
+ dict(
+ type='ShiftScaleRotate',
+ shift_limit=0.0625,
+ scale_limit=0.0,
+ rotate_limit=0,
+ interpolation=1,
+ p=0.5),
+ dict(
+ type='RandomBrightnessContrast',
+ brightness_limit=[0.1, 0.3],
+ contrast_limit=[0.1, 0.3],
+ p=0.2),
+ dict(type='ChannelShuffle', p=0.1),
+ dict(
+ type='OneOf',
+ transforms=[
+ dict(type='Blur', blur_limit=3, p=1.0),
+ dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ],
+ p=0.1),
+ ]
+
+ Args:
+ transforms (list[dict]): A list of albu transformations
+ bbox_params (dict): Bbox_params for albumentation `Compose`
+ keymap (dict): Contains {'input key':'albumentation-style key'}
+ skip_img_without_anno (bool): Whether to skip the image if no ann left
+ after aug
+ """
+
+ def __init__(self,
+ transforms,
+ bbox_params=None,
+ keymap=None,
+ update_pad_shape=False,
+ skip_img_without_anno=False):
+ if Compose is None:
+ raise RuntimeError('albumentations is not installed')
+
+ # Args will be modified later, copying it will be safer
+ transforms = copy.deepcopy(transforms)
+ if bbox_params is not None:
+ bbox_params = copy.deepcopy(bbox_params)
+ if keymap is not None:
+ keymap = copy.deepcopy(keymap)
+ self.transforms = transforms
+ self.filter_lost_elements = False
+ self.update_pad_shape = update_pad_shape
+ self.skip_img_without_anno = skip_img_without_anno
+
+ # A simple workaround to remove masks without boxes
+ if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
+ and 'filter_lost_elements' in bbox_params):
+ self.filter_lost_elements = True
+ self.origin_label_fields = bbox_params['label_fields']
+ bbox_params['label_fields'] = ['idx_mapper']
+ del bbox_params['filter_lost_elements']
+
+ self.bbox_params = (
+ self.albu_builder(bbox_params) if bbox_params else None)
+ self.aug = Compose([self.albu_builder(t) for t in self.transforms],
+ bbox_params=self.bbox_params)
+
+ if not keymap:
+ self.keymap_to_albu = {
+ 'img': 'image',
+ 'gt_masks': 'masks',
+ 'gt_bboxes': 'bboxes'
+ }
+ else:
+ self.keymap_to_albu = keymap
+ self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
+
+ def albu_builder(self, cfg):
+ """Import a module from albumentations.
+
+ It inherits some of :func:`build_from_cfg` logic.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+
+ Returns:
+ obj: The constructed object.
+ """
+
+ assert isinstance(cfg, dict) and 'type' in cfg
+ args = cfg.copy()
+
+ obj_type = args.pop('type')
+ if mmcv.is_str(obj_type):
+ if albumentations is None:
+ raise RuntimeError('albumentations is not installed')
+ obj_cls = getattr(albumentations, obj_type)
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+
+ if 'transforms' in args:
+ args['transforms'] = [
+ self.albu_builder(transform)
+ for transform in args['transforms']
+ ]
+
+ return obj_cls(**args)
+
+ @staticmethod
+ def mapper(d, keymap):
+ """Dictionary mapper. Renames keys according to keymap provided.
+
+ Args:
+ d (dict): old dict
+ keymap (dict): {'old_key':'new_key'}
+ Returns:
+ dict: new dict.
+ """
+
+ updated_dict = {}
+ for k, v in zip(d.keys(), d.values()):
+ new_k = keymap.get(k, k)
+ updated_dict[new_k] = d[k]
+ return updated_dict
+
+ def __call__(self, results):
+ # dict to albumentations format
+ results = self.mapper(results, self.keymap_to_albu)
+ # TODO: add bbox_fields
+ if 'bboxes' in results:
+ # to list of boxes
+ if isinstance(results['bboxes'], np.ndarray):
+ results['bboxes'] = [x for x in results['bboxes']]
+ # add pseudo-field for filtration
+ if self.filter_lost_elements:
+ results['idx_mapper'] = np.arange(len(results['bboxes']))
+
+ # TODO: Support mask structure in albu
+ if 'masks' in results:
+ if isinstance(results['masks'], PolygonMasks):
+ raise NotImplementedError(
+ 'Albu only supports BitMap masks now')
+ ori_masks = results['masks']
+ if albumentations.__version__ < '0.5':
+ results['masks'] = results['masks'].masks
+ else:
+ results['masks'] = [mask for mask in results['masks'].masks]
+
+ results = self.aug(**results)
+
+ if 'bboxes' in results:
+ if isinstance(results['bboxes'], list):
+ results['bboxes'] = np.array(
+ results['bboxes'], dtype=np.float32)
+ results['bboxes'] = results['bboxes'].reshape(-1, 4)
+
+ # filter label_fields
+ if self.filter_lost_elements:
+
+ for label in self.origin_label_fields:
+ results[label] = np.array(
+ [results[label][i] for i in results['idx_mapper']])
+ if 'masks' in results:
+ results['masks'] = np.array(
+ [results['masks'][i] for i in results['idx_mapper']])
+ results['masks'] = ori_masks.__class__(
+ results['masks'], results['image'].shape[0],
+ results['image'].shape[1])
+
+ if (not len(results['idx_mapper'])
+ and self.skip_img_without_anno):
+ return None
+
+ if 'gt_labels' in results:
+ if isinstance(results['gt_labels'], list):
+ results['gt_labels'] = np.array(results['gt_labels'])
+ results['gt_labels'] = results['gt_labels'].astype(np.int64)
+
+ # back to the original format
+ results = self.mapper(results, self.keymap_back)
+
+ # update final shape
+ if self.update_pad_shape:
+ results['pad_shape'] = results['img'].shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCenterCropPad(object):
+ """Random center crop and random around padding for CornerNet.
+
+ This operation generates a randomly cropped image from the original image and
+ pads it simultaneously. Different from :class:`RandomCrop`, the output
+ shape may not equal to ``crop_size`` strictly. We choose a random value
+ from ``ratios`` and the output shape could be larger or smaller than
+ ``crop_size``. The padding operation is also different from :class:`Pad`,
+ here we use around padding instead of right-bottom padding.
+
+ The relation between output image (padding image) and original image:
+
+ .. code:: text
+
+ output image
+
+ +----------------------------+
+ | padded area |
+ +------|----------------------------|----------+
+ | | cropped area | |
+ | | +---------------+ | |
+ | | | . center | | | original image
+ | | | range | | |
+ | | +---------------+ | |
+ +------|----------------------------|----------+
+ | padded area |
+ +----------------------------+
+
+ There are 5 main areas in the figure:
+
+ - output image: output image of this operation, also called padding
+ image in the following instructions.
+ - original image: input image of this operation.
+ - padded area: non-intersect area of output image and original image.
+ - cropped area: the overlap of output image and original image.
+ - center range: a smaller area that the random center is chosen from.
+ The center range is computed from ``border`` and the original image's
+ shape to keep the random center away from the original image's border.
+
+ This operation also acts differently in train and test mode; the two
+ pipelines are summarized below.
+
+ Train pipeline:
+
+ 1. Choose a ``random_ratio`` from ``ratios``, the shape of padding image
+ will be ``random_ratio * crop_size``.
+ 2. Choose a ``random_center`` in center range.
+ 3. Generate the padding image with its center at ``random_center``.
+ 4. Initialize the padding image with pixel values equal to ``mean``.
+ 5. Copy the cropped area to padding image.
+ 6. Refine annotations.
+
+ Test pipeline:
+
+ 1. Compute output shape according to ``test_pad_mode``.
+ 2. Generate the padding image with its center at the original image
+ center.
+ 3. Initialize the padding image with pixel values equal to ``mean``.
+ 4. Copy the ``cropped area`` to padding image.
+
+ Args:
+ crop_size (tuple | None): expected size after crop; the final size is
+ computed according to ratio. Requires (h, w) in train mode, and
+ None in test mode.
+ ratios (tuple): random select a ratio from tuple and crop image to
+ (crop_size[0] * ratio) * (crop_size[1] * ratio).
+ Only available in train mode.
+ border (int): max distance from center select area to image border.
+ Only available in train mode.
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB.
+ test_mode (bool): whether to involve random variables in the transform.
+ In train mode, crop_size is fixed, and the center coords and ratio are
+ randomly selected from predefined lists. In test mode, crop_size is
+ the image's original shape, and the center coords and ratio are fixed.
+ test_pad_mode (tuple): padding method and padding shape value, only
+ available in test mode. Default is using 'logical_or' with
+ 127 as padding shape value.
+
+ - 'logical_or': final_shape = input_shape | padding_shape_value
+ - 'size_divisor': final_shape = int(
+ ceil(input_shape / padding_shape_value) * padding_shape_value)
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+ """
+
+ def __init__(self,
+ crop_size=None,
+ ratios=(0.9, 1.0, 1.1),
+ border=128,
+ mean=None,
+ std=None,
+ to_rgb=None,
+ test_mode=False,
+ test_pad_mode=('logical_or', 127),
+ bbox_clip_border=True):
+ if test_mode:
+ assert crop_size is None, 'crop_size must be None in test mode'
+ assert ratios is None, 'ratios must be None in test mode'
+ assert border is None, 'border must be None in test mode'
+ assert isinstance(test_pad_mode, (list, tuple))
+ assert test_pad_mode[0] in ['logical_or', 'size_divisor']
+ else:
+ assert isinstance(crop_size, (list, tuple))
+ assert crop_size[0] > 0 and crop_size[1] > 0, (
+ 'crop_size must be > 0 in train mode')
+ assert isinstance(ratios, (list, tuple))
+ assert test_pad_mode is None, (
+ 'test_pad_mode must be None in train mode')
+
+ self.crop_size = crop_size
+ self.ratios = ratios
+ self.border = border
+ # We do not set default value to mean, std and to_rgb because these
+ # hyper-parameters are easy to forget but could affect the performance.
+ # Please use the same setting as Normalize for performance assurance.
+ assert mean is not None and std is not None and to_rgb is not None
+ self.to_rgb = to_rgb
+ self.input_mean = mean
+ self.input_std = std
+ if to_rgb:
+ self.mean = mean[::-1]
+ self.std = std[::-1]
+ else:
+ self.mean = mean
+ self.std = std
+ self.test_mode = test_mode
+ self.test_pad_mode = test_pad_mode
+ self.bbox_clip_border = bbox_clip_border
+
+ def _get_border(self, border, size):
+ """Get final border for the target size.
+
+ This function generates a ``final_border`` according to image's shape.
+ The area between ``final_border`` and ``size - final_border`` is the
+ ``center range``. We randomly choose the center from the ``center range``
+ to avoid the random center being too close to the original image's border.
+ Also ``center range`` should be larger than 0.
+
+ Args:
+ border (int): The initial border, default is 128.
+ size (int): The width or height of original image.
+ Returns:
+ int: The final border.
+ """
+ k = 2 * border / size
+ i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
+ return border // i
+
+ def _filter_boxes(self, patch, boxes):
+ """Check whether the center of each box is in the patch.
+
+ Args:
+ patch (list[int]): The cropped area, [left, top, right, bottom].
+ boxes (numpy array, (N x 4)): Ground truth boxes.
+
+ Returns:
+ mask (numpy array, (N,)): Each box is inside or outside the patch.
+ """
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
+ center[:, 0] < patch[2]) * (
+ center[:, 1] < patch[3])
+ return mask
+
+ def _crop_image_and_paste(self, image, center, size):
+ """Crop image with a given center and size, then paste the cropped
+ image to a blank image with the two centers aligned.
+
+ This function is equivalent to generating a blank image with ``size``
+ as its shape. Then cover it on the original image with two centers (
+ the center of blank image and the random center of original image)
+ aligned. The overlap area is paste from the original image and the
+ outside area is filled with ``mean pixel``.
+
+ Args:
+ image (np array, H x W x C): Original image.
+ center (list[int]): Target crop center coord.
+ size (list[int]): Target crop size. [target_h, target_w]
+
+ Returns:
+ cropped_img (np array, target_h x target_w x C): Cropped image.
+ border (np array, 4): The distance of four border of
+ ``cropped_img`` to the original image area, [top, bottom,
+ left, right]
+ patch (list[int]): The cropped area, [left, top, right, bottom].
+ """
+ center_y, center_x = center
+ target_h, target_w = size
+ img_h, img_w, img_c = image.shape
+
+ x0 = max(0, center_x - target_w // 2)
+ x1 = min(center_x + target_w // 2, img_w)
+ y0 = max(0, center_y - target_h // 2)
+ y1 = min(center_y + target_h // 2, img_h)
+ patch = np.array((int(x0), int(y0), int(x1), int(y1)))
+
+ left, right = center_x - x0, x1 - center_x
+ top, bottom = center_y - y0, y1 - center_y
+
+ cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
+ cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
+ for i in range(img_c):
+ cropped_img[:, :, i] += self.mean[i]
+ y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
+ x_slice = slice(cropped_center_x - left, cropped_center_x + right)
+ cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]
+
+ border = np.array([
+ cropped_center_y - top, cropped_center_y + bottom,
+ cropped_center_x - left, cropped_center_x + right
+ ],
+ dtype=np.float32)
+
+ return cropped_img, border, patch
+
+ def _train_aug(self, results):
+ """Random crop and around padding the original image.
+
+ Args:
+ results (dict): Image information in the augmentation pipeline.
+
+ Returns:
+ results (dict): The updated dict.
+ """
+ img = results['img']
+ h, w, c = img.shape
+ boxes = results['gt_bboxes']
+ while True:
+ scale = random.choice(self.ratios)
+ new_h = int(self.crop_size[0] * scale)
+ new_w = int(self.crop_size[1] * scale)
+ h_border = self._get_border(self.border, h)
+ w_border = self._get_border(self.border, w)
+
+ for i in range(50):
+ center_x = random.randint(low=w_border, high=w - w_border)
+ center_y = random.randint(low=h_border, high=h - h_border)
+
+ cropped_img, border, patch = self._crop_image_and_paste(
+ img, [center_y, center_x], [new_h, new_w])
+
+ mask = self._filter_boxes(patch, boxes)
+ # if the image has no gt bbox at all, any crop patch is valid;
+ # otherwise require at least one box center inside the patch.
+ if not mask.any() and len(boxes) > 0:
+ continue
+
+ results['img'] = cropped_img
+ results['img_shape'] = cropped_img.shape
+ results['pad_shape'] = cropped_img.shape
+
+ x0, y0, x1, y1 = patch
+
+ left_w, top_h = center_x - x0, center_y - y0
+ cropped_center_x, cropped_center_y = new_w // 2, new_h // 2
+
+ # crop bboxes accordingly and clip to the image boundary
+ for key in results.get('bbox_fields', []):
+ mask = self._filter_boxes(patch, results[key])
+ bboxes = results[key][mask]
+ bboxes[:, 0:4:2] += cropped_center_x - left_w - x0
+ bboxes[:, 1:4:2] += cropped_center_y - top_h - y0
+ if self.bbox_clip_border:
+ bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
+ bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
+ keep = (bboxes[:, 2] > bboxes[:, 0]) & (
+ bboxes[:, 3] > bboxes[:, 1])
+ bboxes = bboxes[keep]
+ results[key] = bboxes
+ if key in ['gt_bboxes']:
+ if 'gt_labels' in results:
+ labels = results['gt_labels'][mask]
+ labels = labels[keep]
+ results['gt_labels'] = labels
+ if 'gt_masks' in results:
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports bbox.')
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports bbox.')
+ return results
+
+ def _test_aug(self, results):
+ """Around padding the original image without cropping.
+
+ The padding mode and value are from ``test_pad_mode``.
+
+ Args:
+ results (dict): Image infomations in the augment pipeline.
+
+ Returns:
+ results (dict): The updated dict.
+ """
+ img = results['img']
+ h, w, c = img.shape
+ results['img_shape'] = img.shape
+ if self.test_pad_mode[0] in ['logical_or']:
+ target_h = h | self.test_pad_mode[1]
+ target_w = w | self.test_pad_mode[1]
+ elif self.test_pad_mode[0] in ['size_divisor']:
+ divisor = self.test_pad_mode[1]
+ target_h = int(np.ceil(h / divisor)) * divisor
+ target_w = int(np.ceil(w / divisor)) * divisor
+ else:
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports two testing pad modes: '
+ 'logical_or and size_divisor.')
+
+ cropped_img, border, _ = self._crop_image_and_paste(
+ img, [h // 2, w // 2], [target_h, target_w])
+ results['img'] = cropped_img
+ results['pad_shape'] = cropped_img.shape
+ results['border'] = border
+ return results
+
+ def __call__(self, results):
+ img = results['img']
+ assert img.dtype == np.float32, (
+ 'RandomCenterCropPad needs the input image of dtype np.float32,'
+ ' please set "to_float32=True" in "LoadImageFromFile" pipeline')
+ h, w, c = img.shape
+ assert c == len(self.mean)
+ if self.test_mode:
+ return self._test_aug(results)
+ else:
+ return self._train_aug(results)
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(crop_size={self.crop_size}, '
+ repr_str += f'ratios={self.ratios}, '
+ repr_str += f'border={self.border}, '
+ repr_str += f'mean={self.input_mean}, '
+ repr_str += f'std={self.input_std}, '
+ repr_str += f'to_rgb={self.to_rgb}, '
+ repr_str += f'test_mode={self.test_mode}, '
+ repr_str += f'test_pad_mode={self.test_pad_mode}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
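+
+# Illustrative usage (an assumption, not part of the original file): a
+# CenterNet-style training pipeline could configure this transform
+# roughly as
+#
+#   dict(
+#       type='RandomCenterCropPad',
+#       crop_size=(511, 511),
+#       ratios=(0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3),
+#       mean=[0, 0, 0],
+#       std=[1, 1, 1],
+#       to_rgb=True,
+#       test_mode=False,
+#       test_pad_mode=None)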
+
+
+@PIPELINES.register_module()
+class CutOut(object):
+ """CutOut operation.
+
+ Randomly drop some regions of the image, as used in
+ `Cutout <https://arxiv.org/abs/1708.04552>`_.
+
+ Args:
+ n_holes (int | tuple[int, int]): Number of regions to be dropped.
+ If it is given as a list, number of holes will be randomly
+ selected from the closed interval [`n_holes[0]`, `n_holes[1]`].
+ cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
+ shape of dropped regions. It can be `tuple[int, int]` to use a
+ fixed cutout shape, or `list[tuple[int, int]]` to randomly choose
+ shape from the list.
+ cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
+ candidate ratio of dropped regions. It can be `tuple[float, float]`
+ to use a fixed ratio or `list[tuple[float, float]]` to randomly
+ choose ratio from the list. Please note that `cutout_shape`
+ and `cutout_ratio` cannot be both given at the same time.
+ fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
+ of pixel to fill in the dropped regions. Default: (0, 0, 0).
+ """
+
+ def __init__(self,
+ n_holes,
+ cutout_shape=None,
+ cutout_ratio=None,
+ fill_in=(0, 0, 0)):
+
+ assert (cutout_shape is None) ^ (cutout_ratio is None), \
+ 'Either cutout_shape or cutout_ratio should be specified.'
+ assert (isinstance(cutout_shape, (list, tuple))
+ or isinstance(cutout_ratio, (list, tuple)))
+ if isinstance(n_holes, tuple):
+ assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
+ else:
+ n_holes = (n_holes, n_holes)
+ self.n_holes = n_holes
+ self.fill_in = fill_in
+ self.with_ratio = cutout_ratio is not None
+ self.candidates = cutout_ratio if self.with_ratio else cutout_shape
+ if not isinstance(self.candidates, list):
+ self.candidates = [self.candidates]
+
+ def __call__(self, results):
+ """Call function to drop some regions of image."""
+ h, w, c = results['img'].shape
+ n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
+ for _ in range(n_holes):
+ x1 = np.random.randint(0, w)
+ y1 = np.random.randint(0, h)
+ index = np.random.randint(0, len(self.candidates))
+ if not self.with_ratio:
+ cutout_w, cutout_h = self.candidates[index]
+ else:
+ cutout_w = int(self.candidates[index][0] * w)
+ cutout_h = int(self.candidates[index][1] * h)
+
+ x2 = np.clip(x1 + cutout_w, 0, w)
+ y2 = np.clip(y1 + cutout_h, 0, h)
+ results['img'][y1:y2, x1:x2, :] = self.fill_in
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(n_holes={self.n_holes}, '
+ repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
+ else f'cutout_shape={self.candidates}, ')
+ repr_str += f'fill_in={self.fill_in})'
+ return repr_str
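+
+# Illustrative usage (an assumption, not part of the original file): a
+# training pipeline could drop up to three fixed-size regions per image
+# with something like
+#
+#   dict(type='CutOut', n_holes=(1, 3), cutout_shape=(32, 32))
+#
+# or use relative sizes instead, e.g.
+# cutout_ratio=[(0.05, 0.05), (0.1, 0.1)]; only one of cutout_shape and
+# cutout_ratio may be given.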
diff --git a/mmdet/datasets/samplers/__init__.py b/mmdet/datasets/samplers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2596aeb2ccfc85b58624713c04453d34e94a4062
--- /dev/null
+++ b/mmdet/datasets/samplers/__init__.py
@@ -0,0 +1,4 @@
+from .distributed_sampler import DistributedSampler
+from .group_sampler import DistributedGroupSampler, GroupSampler
+
+__all__ = ['DistributedSampler', 'DistributedGroupSampler', 'GroupSampler']
diff --git a/mmdet/datasets/samplers/distributed_sampler.py b/mmdet/datasets/samplers/distributed_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc61019484655ee2829f7908dc442caa20cf1d54
--- /dev/null
+++ b/mmdet/datasets/samplers/distributed_sampler.py
@@ -0,0 +1,39 @@
+import math
+
+import torch
+from torch.utils.data import DistributedSampler as _DistributedSampler
+
+
+class DistributedSampler(_DistributedSampler):
+
+ def __init__(self,
+ dataset,
+ num_replicas=None,
+ rank=None,
+ shuffle=True,
+ seed=0):
+ super().__init__(
+ dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
+ # for the compatibility from PyTorch 1.3+
+ self.seed = seed if seed is not None else 0
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ if self.shuffle:
+ g = torch.Generator()
+ g.manual_seed(self.epoch + self.seed)
+ indices = torch.randperm(len(self.dataset), generator=g).tolist()
+ else:
+ indices = torch.arange(len(self.dataset)).tolist()
+
+ # add extra samples to make it evenly divisible
+ # in case that indices is shorter than half of total_size
+ indices = (indices *
+ math.ceil(self.total_size / len(indices)))[:self.total_size]
+ assert len(indices) == self.total_size
+
+ # subsample
+ indices = indices[self.rank:self.total_size:self.num_replicas]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
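+
+# Illustrative example (an assumption, not part of the original file):
+# with a dataset of 10 samples and num_replicas=4, the base class sets
+# num_samples=3 and total_size=12; the shuffled index list is tiled to
+# length 12 and rank r keeps indices[r:12:4], i.e. 3 samples per rank.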
diff --git a/mmdet/datasets/samplers/group_sampler.py b/mmdet/datasets/samplers/group_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f88cf3439446a2eb7d8656388ddbe93196315f5b
--- /dev/null
+++ b/mmdet/datasets/samplers/group_sampler.py
@@ -0,0 +1,148 @@
+from __future__ import division
+import math
+
+import numpy as np
+import torch
+from mmcv.runner import get_dist_info
+from torch.utils.data import Sampler
+
+
+class GroupSampler(Sampler):
+
+ def __init__(self, dataset, samples_per_gpu=1):
+ assert hasattr(dataset, 'flag')
+ self.dataset = dataset
+ self.samples_per_gpu = samples_per_gpu
+ self.flag = dataset.flag.astype(np.int64)
+ self.group_sizes = np.bincount(self.flag)
+ self.num_samples = 0
+ for i, size in enumerate(self.group_sizes):
+ self.num_samples += int(np.ceil(
+ size / self.samples_per_gpu)) * self.samples_per_gpu
+
+ def __iter__(self):
+ indices = []
+ for i, size in enumerate(self.group_sizes):
+ if size == 0:
+ continue
+ indice = np.where(self.flag == i)[0]
+ assert len(indice) == size
+ np.random.shuffle(indice)
+ num_extra = int(np.ceil(size / self.samples_per_gpu)
+ ) * self.samples_per_gpu - len(indice)
+ indice = np.concatenate(
+ [indice, np.random.choice(indice, num_extra)])
+ indices.append(indice)
+ indices = np.concatenate(indices)
+ indices = [
+ indices[i * self.samples_per_gpu:(i + 1) * self.samples_per_gpu]
+ for i in np.random.permutation(
+ range(len(indices) // self.samples_per_gpu))
+ ]
+ indices = np.concatenate(indices)
+ indices = indices.astype(np.int64).tolist()
+ assert len(indices) == self.num_samples
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+
+class DistributedGroupSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such a case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+
+ .. note::
+ Dataset is assumed to be of constant size.
+
+ Arguments:
+ dataset: Dataset used for sampling.
+ num_replicas (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within num_replicas.
+ seed (int, optional): random seed used to shuffle the sampler if
+ ``shuffle=True``. This number should be identical across all
+ processes in the distributed group. Default: 0.
+ """
+
+ def __init__(self,
+ dataset,
+ samples_per_gpu=1,
+ num_replicas=None,
+ rank=None,
+ seed=0):
+ _rank, _num_replicas = get_dist_info()
+ if num_replicas is None:
+ num_replicas = _num_replicas
+ if rank is None:
+ rank = _rank
+ self.dataset = dataset
+ self.samples_per_gpu = samples_per_gpu
+ self.num_replicas = num_replicas
+ self.rank = rank
+ self.epoch = 0
+ self.seed = seed if seed is not None else 0
+
+ assert hasattr(self.dataset, 'flag')
+ self.flag = self.dataset.flag
+ self.group_sizes = np.bincount(self.flag)
+
+ self.num_samples = 0
+ for i, j in enumerate(self.group_sizes):
+ self.num_samples += int(
+ math.ceil(self.group_sizes[i] * 1.0 / self.samples_per_gpu /
+ self.num_replicas)) * self.samples_per_gpu
+ self.total_size = self.num_samples * self.num_replicas
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch + self.seed)
+
+ indices = []
+ for i, size in enumerate(self.group_sizes):
+ if size > 0:
+ indice = np.where(self.flag == i)[0]
+ assert len(indice) == size
+ # add .numpy() to avoid a bug when selecting indices in parrots.
+ # TODO: check whether torch.randperm() can be replaced by
+ # numpy.random.permutation().
+ indice = indice[list(
+ torch.randperm(int(size), generator=g).numpy())].tolist()
+ extra = int(
+ math.ceil(
+ size * 1.0 / self.samples_per_gpu / self.num_replicas)
+ ) * self.samples_per_gpu * self.num_replicas - len(indice)
+ # pad indice
+ tmp = indice.copy()
+ for _ in range(extra // size):
+ indice.extend(tmp)
+ indice.extend(tmp[:extra % size])
+ indices.extend(indice)
+
+ assert len(indices) == self.total_size
+
+ indices = [
+ indices[j] for i in list(
+ torch.randperm(
+ len(indices) // self.samples_per_gpu, generator=g))
+ for j in range(i * self.samples_per_gpu, (i + 1) *
+ self.samples_per_gpu)
+ ]
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset:offset + self.num_samples]
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
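+
+# Illustrative usage (an assumption, not part of the original file):
+#
+#   from torch.utils.data import DataLoader
+#   sampler = DistributedGroupSampler(dataset, samples_per_gpu=2)
+#   loader = DataLoader(dataset, batch_size=2, sampler=sampler)
+#   for epoch in range(max_epochs):
+#       sampler.set_epoch(epoch)  # re-seed the per-epoch shuffle
+#       for batch in loader:
+#           ...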
diff --git a/mmdet/datasets/utils.py b/mmdet/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..157c9a2e1fe009552fdec9b9c9e7a33ed46d51ff
--- /dev/null
+++ b/mmdet/datasets/utils.py
@@ -0,0 +1,158 @@
+import copy
+import warnings
+
+from mmcv.cnn import VGG
+from mmcv.runner.hooks import HOOKS, Hook
+
+from mmdet.datasets.builder import PIPELINES
+from mmdet.datasets.pipelines import LoadAnnotations, LoadImageFromFile
+from mmdet.models.dense_heads import GARPNHead, RPNHead
+from mmdet.models.roi_heads.mask_heads import FusedSemanticHead
+
+
+def replace_ImageToTensor(pipelines):
+ """Replace the ImageToTensor transform in a data pipeline to
+ DefaultFormatBundle, which is normally useful in batch inference.
+
+ Args:
+ pipelines (list[dict]): Data pipeline configs.
+
+ Returns:
+ list: The new pipeline list with all ImageToTensor replaced by
+ DefaultFormatBundle.
+
+ Examples:
+ >>> pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(
+ ... type='MultiScaleFlipAug',
+ ... img_scale=(1333, 800),
+ ... flip=False,
+ ... transforms=[
+ ... dict(type='Resize', keep_ratio=True),
+ ... dict(type='RandomFlip'),
+ ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
+ ... dict(type='Pad', size_divisor=32),
+ ... dict(type='ImageToTensor', keys=['img']),
+ ... dict(type='Collect', keys=['img']),
+ ... ])
+ ... ]
+ >>> expected_pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(
+ ... type='MultiScaleFlipAug',
+ ... img_scale=(1333, 800),
+ ... flip=False,
+ ... transforms=[
+ ... dict(type='Resize', keep_ratio=True),
+ ... dict(type='RandomFlip'),
+ ... dict(type='Normalize', mean=[0, 0, 0], std=[1, 1, 1]),
+ ... dict(type='Pad', size_divisor=32),
+ ... dict(type='DefaultFormatBundle'),
+ ... dict(type='Collect', keys=['img']),
+ ... ])
+ ... ]
+ >>> assert expected_pipelines == replace_ImageToTensor(pipelines)
+ """
+ pipelines = copy.deepcopy(pipelines)
+ for i, pipeline in enumerate(pipelines):
+ if pipeline['type'] == 'MultiScaleFlipAug':
+ assert 'transforms' in pipeline
+ pipeline['transforms'] = replace_ImageToTensor(
+ pipeline['transforms'])
+ elif pipeline['type'] == 'ImageToTensor':
+ warnings.warn(
+ '"ImageToTensor" pipeline is replaced by '
+ '"DefaultFormatBundle" for batch inference. It is '
+ 'recommended to manually replace it in the test '
+ 'data pipeline in your config file.', UserWarning)
+ pipelines[i] = {'type': 'DefaultFormatBundle'}
+ return pipelines
+
+
+def get_loading_pipeline(pipeline):
+ """Only keep loading image and annotations related configuration.
+
+ Args:
+ pipeline (list[dict]): Data pipeline configs.
+
+ Returns:
+ list[dict]: The new pipeline list that only keeps the
+ image- and annotation-loading related configuration.
+
+ Examples:
+ >>> pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(type='LoadAnnotations', with_bbox=True),
+ ... dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+ ... dict(type='RandomFlip', flip_ratio=0.5),
+ ... dict(type='Normalize', **img_norm_cfg),
+ ... dict(type='Pad', size_divisor=32),
+ ... dict(type='DefaultFormatBundle'),
+ ... dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels'])
+ ... ]
+ >>> expected_pipelines = [
+ ... dict(type='LoadImageFromFile'),
+ ... dict(type='LoadAnnotations', with_bbox=True)
+ ... ]
+ >>> assert expected_pipelines ==\
+ ... get_loading_pipeline(pipelines)
+ """
+ loading_pipeline_cfg = []
+ for cfg in pipeline:
+ obj_cls = PIPELINES.get(cfg['type'])
+ # TODO: use a more elegant way to distinguish loading modules
+ if obj_cls is not None and obj_cls in (LoadImageFromFile,
+ LoadAnnotations):
+ loading_pipeline_cfg.append(cfg)
+ assert len(loading_pipeline_cfg) == 2, \
+ 'The data pipeline in your config file must include ' \
+ 'loading image and annotations related pipeline.'
+ return loading_pipeline_cfg
+
+
+@HOOKS.register_module()
+class NumClassCheckHook(Hook):
+
+ def _check_head(self, runner):
+ """Check whether the `num_classes` in head matches the length of
+ `CLASSSES` in `dataset`.
+
+ Args:
+ runner (obj:`EpochBasedRunner`): Epoch based Runner.
+ """
+ model = runner.model
+ dataset = runner.data_loader.dataset
+ if dataset.CLASSES is None:
+ runner.logger.warning(
+ f'Please set `CLASSES` '
+ f'in the {dataset.__class__.__name__} and '
+ f'check if it is consistent with the `num_classes` '
+ f'of the head')
+ else:
+ for name, module in model.named_modules():
+ if hasattr(module, 'num_classes') and not isinstance(
+ module, (RPNHead, VGG, FusedSemanticHead, GARPNHead)):
+ assert module.num_classes == len(dataset.CLASSES), \
+ (f'The `num_classes` ({module.num_classes}) in '
+ f'{module.__class__.__name__} of '
+ f'{model.__class__.__name__} does not match '
+ f'the length of `CLASSES` '
+ f'({len(dataset.CLASSES)}) in '
+ f'{dataset.__class__.__name__}')
+
+ def before_train_epoch(self, runner):
+ """Check whether the training dataset is compatible with head.
+
+ Args:
+ runner (obj:`EpochBasedRunner`): Epoch based Runner.
+ """
+ self._check_head(runner)
+
+ def before_val_epoch(self, runner):
+ """Check whether the dataset in val epoch is compatible with head.
+
+ Args:
+ runner (obj:`EpochBasedRunner`): Epoch based Runner.
+ """
+ self._check_head(runner)
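+
+# Illustrative usage (an assumption, not part of the original file): the
+# hook can be enabled from a config through the custom hooks mechanism,
+# e.g.
+#
+#   custom_hooks = [dict(type='NumClassCheckHook')]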
diff --git a/mmdet/datasets/voc.py b/mmdet/datasets/voc.py
new file mode 100644
index 0000000000000000000000000000000000000000..abd4cb8947238936faff48fc92c093c8ae06daff
--- /dev/null
+++ b/mmdet/datasets/voc.py
@@ -0,0 +1,93 @@
+from collections import OrderedDict
+
+from mmcv.utils import print_log
+
+from mmdet.core import eval_map, eval_recalls
+from .builder import DATASETS
+from .xml_style import XMLDataset
+
+
+@DATASETS.register_module()
+class VOCDataset(XMLDataset):
+
+ CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
+ 'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
+ 'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
+ 'tvmonitor')
+
+ def __init__(self, **kwargs):
+ super(VOCDataset, self).__init__(**kwargs)
+ if 'VOC2007' in self.img_prefix:
+ self.year = 2007
+ elif 'VOC2012' in self.img_prefix:
+ self.year = 2012
+ else:
+ raise ValueError('Cannot infer dataset year from img_prefix')
+
+ def evaluate(self,
+ results,
+ metric='mAP',
+ logger=None,
+ proposal_nums=(100, 300, 1000),
+ iou_thr=0.5,
+ scale_ranges=None):
+ """Evaluate in VOC protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'mAP', 'recall'.
+ logger (logging.Logger | str, optional): Logger used for printing
+ related information during evaluation. Default: None.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thr (float | list[float]): IoU threshold. Default: 0.5.
+ scale_ranges (list[tuple], optional): Scale ranges for evaluating
+ mAP. If not specified, all bounding boxes would be included in
+ evaluation. Default: None.
+
+ Returns:
+ dict[str, float]: AP/recall metrics.
+ """
+
+ if not isinstance(metric, str):
+ assert len(metric) == 1
+ metric = metric[0]
+ allowed_metrics = ['mAP', 'recall']
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ annotations = [self.get_ann_info(i) for i in range(len(self))]
+ eval_results = OrderedDict()
+ iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
+ if metric == 'mAP':
+ assert isinstance(iou_thrs, list)
+ if self.year == 2007:
+ ds_name = 'voc07'
+ else:
+ ds_name = self.CLASSES
+ mean_aps = []
+ for iou_thr in iou_thrs:
+ print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
+ mean_ap, _ = eval_map(
+ results,
+ annotations,
+ scale_ranges=None,
+ iou_thr=iou_thr,
+ dataset=ds_name,
+ logger=logger)
+ mean_aps.append(mean_ap)
+ eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
+ eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
+ elif metric == 'recall':
+ gt_bboxes = [ann['bboxes'] for ann in annotations]
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
+ for i, num in enumerate(proposal_nums):
+ # iterate over the normalized ``iou_thrs`` so that a scalar
+ # ``iou_thr`` argument also works here
+ for j, iou in enumerate(iou_thrs):
+ eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
+ if recalls.shape[1] > 1:
+ ar = recalls.mean(axis=1)
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ return eval_results
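+
+# Illustrative usage (an assumption, not part of the original file):
+#
+#   metrics = dataset.evaluate(results, metric='mAP',
+#                              iou_thr=[0.5, 0.75])
+#   # -> {'AP50': ..., 'AP75': ..., 'mAP': ...}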
diff --git a/mmdet/datasets/wider_face.py b/mmdet/datasets/wider_face.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a13907db87a9986a7d701837259a0b712fc9dca
--- /dev/null
+++ b/mmdet/datasets/wider_face.py
@@ -0,0 +1,51 @@
+import os.path as osp
+import xml.etree.ElementTree as ET
+
+import mmcv
+
+from .builder import DATASETS
+from .xml_style import XMLDataset
+
+
+@DATASETS.register_module()
+class WIDERFaceDataset(XMLDataset):
+ """Reader for the WIDER Face dataset in PASCAL VOC format.
+
+ Conversion scripts can be found in
+ https://github.com/sovrasov/wider-face-pascal-voc-annotations
+ """
+ CLASSES = ('face', )
+
+ def __init__(self, **kwargs):
+ super(WIDERFaceDataset, self).__init__(**kwargs)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from WIDERFace XML style annotation file.
+
+ Args:
+ ann_file (str): Path of XML file.
+
+ Returns:
+ list[dict]: Annotation info from XML file.
+ """
+
+ data_infos = []
+ img_ids = mmcv.list_from_file(ann_file)
+ for img_id in img_ids:
+ filename = f'{img_id}.jpg'
+ xml_path = osp.join(self.img_prefix, 'Annotations',
+ f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ size = root.find('size')
+ width = int(size.find('width').text)
+ height = int(size.find('height').text)
+ folder = root.find('folder').text
+ data_infos.append(
+ dict(
+ id=img_id,
+ filename=osp.join(folder, filename),
+ width=width,
+ height=height))
+
+ return data_infos
diff --git a/mmdet/datasets/xml_style.py b/mmdet/datasets/xml_style.py
new file mode 100644
index 0000000000000000000000000000000000000000..71069488b0f6da3b37e588228f44460ce5f00679
--- /dev/null
+++ b/mmdet/datasets/xml_style.py
@@ -0,0 +1,170 @@
+import os.path as osp
+import xml.etree.ElementTree as ET
+
+import mmcv
+import numpy as np
+from PIL import Image
+
+from .builder import DATASETS
+from .custom import CustomDataset
+
+
+@DATASETS.register_module()
+class XMLDataset(CustomDataset):
+ """XML dataset for detection.
+
+ Args:
+ min_size (int | float, optional): The minimum size of bounding
+ boxes in the images. If the size of a bounding box is less than
+ ``min_size``, it will be added to the ignored field.
+ """
+
+ def __init__(self, min_size=None, **kwargs):
+ assert self.CLASSES or kwargs.get(
+ 'classes', None), 'CLASSES in `XMLDataset` can not be None.'
+ super(XMLDataset, self).__init__(**kwargs)
+ self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
+ self.min_size = min_size
+
+ def load_annotations(self, ann_file):
+ """Load annotation from XML style ann_file.
+
+ Args:
+ ann_file (str): Path of XML file.
+
+ Returns:
+ list[dict]: Annotation info from XML file.
+ """
+
+ data_infos = []
+ img_ids = mmcv.list_from_file(ann_file)
+ for img_id in img_ids:
+ filename = f'JPEGImages/{img_id}.jpg'
+ xml_path = osp.join(self.img_prefix, 'Annotations',
+ f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ size = root.find('size')
+ if size is not None:
+ width = int(size.find('width').text)
+ height = int(size.find('height').text)
+ else:
+ img_path = osp.join(self.img_prefix, 'JPEGImages',
+ '{}.jpg'.format(img_id))
+ img = Image.open(img_path)
+ width, height = img.size
+ data_infos.append(
+ dict(id=img_id, filename=filename, width=width, height=height))
+
+ return data_infos
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without annotation."""
+ valid_inds = []
+ for i, img_info in enumerate(self.data_infos):
+ if min(img_info['width'], img_info['height']) < min_size:
+ continue
+ if self.filter_empty_gt:
+ img_id = img_info['id']
+ xml_path = osp.join(self.img_prefix, 'Annotations',
+ f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ for obj in root.findall('object'):
+ name = obj.find('name').text
+ if name in self.CLASSES:
+ valid_inds.append(i)
+ break
+ else:
+ valid_inds.append(i)
+ return valid_inds
+
+ def get_ann_info(self, idx):
+ """Get annotation from XML file by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ bboxes = []
+ labels = []
+ bboxes_ignore = []
+ labels_ignore = []
+ for obj in root.findall('object'):
+ name = obj.find('name').text
+ if name not in self.CLASSES:
+ continue
+ label = self.cat2label[name]
+ difficult = obj.find('difficult')
+ difficult = 0 if difficult is None else int(difficult.text)
+ bnd_box = obj.find('bndbox')
+ # TODO: check whether it is necessary to use int
+ # Coordinates may be float type
+ bbox = [
+ int(float(bnd_box.find('xmin').text)),
+ int(float(bnd_box.find('ymin').text)),
+ int(float(bnd_box.find('xmax').text)),
+ int(float(bnd_box.find('ymax').text))
+ ]
+ ignore = False
+ if self.min_size:
+ assert not self.test_mode
+ w = bbox[2] - bbox[0]
+ h = bbox[3] - bbox[1]
+ if w < self.min_size or h < self.min_size:
+ ignore = True
+ if difficult or ignore:
+ bboxes_ignore.append(bbox)
+ labels_ignore.append(label)
+ else:
+ bboxes.append(bbox)
+ labels.append(label)
+ if not bboxes:
+ bboxes = np.zeros((0, 4))
+ labels = np.zeros((0, ))
+ else:
+ bboxes = np.array(bboxes, ndmin=2) - 1
+ labels = np.array(labels)
+ if not bboxes_ignore:
+ bboxes_ignore = np.zeros((0, 4))
+ labels_ignore = np.zeros((0, ))
+ else:
+ bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
+ labels_ignore = np.array(labels_ignore)
+ ann = dict(
+ bboxes=bboxes.astype(np.float32),
+ labels=labels.astype(np.int64),
+ bboxes_ignore=bboxes_ignore.astype(np.float32),
+ labels_ignore=labels_ignore.astype(np.int64))
+ return ann
+
+ def get_cat_ids(self, idx):
+ """Get category ids in XML file by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ cat_ids = []
+ img_id = self.data_infos[idx]['id']
+ xml_path = osp.join(self.img_prefix, 'Annotations', f'{img_id}.xml')
+ tree = ET.parse(xml_path)
+ root = tree.getroot()
+ for obj in root.findall('object'):
+ name = obj.find('name').text
+ if name not in self.CLASSES:
+ continue
+ label = self.cat2label[name]
+ cat_ids.append(label)
+
+ return cat_ids
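+
+# Illustrative example (an assumption, not part of the original file):
+# for an image with two annotated, non-difficult objects,
+# ``get_ann_info`` returns a dict of the form
+#
+#   dict(bboxes=<float32 array of shape (2, 4)>,
+#        labels=<int64 array of shape (2,)>,
+#        bboxes_ignore=<float32 array of shape (0, 4)>,
+#        labels_ignore=<int64 array of shape (0,)>)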
diff --git a/mmdet/models/__init__.py b/mmdet/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..44ac99855ae52101c91be167fa78d8219fc47259
--- /dev/null
+++ b/mmdet/models/__init__.py
@@ -0,0 +1,16 @@
+from .backbones import * # noqa: F401,F403
+from .builder import (BACKBONES, DETECTORS, HEADS, LOSSES, NECKS,
+ ROI_EXTRACTORS, SHARED_HEADS, build_backbone,
+ build_detector, build_head, build_loss, build_neck,
+ build_roi_extractor, build_shared_head)
+from .dense_heads import * # noqa: F401,F403
+from .detectors import * # noqa: F401,F403
+from .losses import * # noqa: F401,F403
+from .necks import * # noqa: F401,F403
+from .roi_heads import * # noqa: F401,F403
+
+__all__ = [
+ 'BACKBONES', 'NECKS', 'ROI_EXTRACTORS', 'SHARED_HEADS', 'HEADS', 'LOSSES',
+ 'DETECTORS', 'build_backbone', 'build_neck', 'build_roi_extractor',
+ 'build_shared_head', 'build_head', 'build_loss', 'build_detector'
+]
diff --git a/mmdet/models/backbones/__init__.py b/mmdet/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..11d7de7543b04e7040facb4472121e5c0f02ecaa
--- /dev/null
+++ b/mmdet/models/backbones/__init__.py
@@ -0,0 +1,3 @@
+from .swin_transformer import SwinTransformer
+from .resnet import ResNet, ResNetV1d
+__all__ = ['SwinTransformer', 'ResNet', 'ResNetV1d']
diff --git a/mmdet/models/backbones/darknet.py b/mmdet/models/backbones/darknet.py
new file mode 100644
index 0000000000000000000000000000000000000000..517fe26259217792e0dad80ca3824d914cfe3904
--- /dev/null
+++ b/mmdet/models/backbones/darknet.py
@@ -0,0 +1,199 @@
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import logging
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from ..builder import BACKBONES
+
+
+class ResBlock(nn.Module):
+ """The basic residual block used in Darknet. Each ResBlock consists of two
+ ConvModules and the input is added to the final output. Each ConvModule is
+ composed of Conv, BN, and LeakyReLU. Following the YOLOv3 paper, the
+ first conv layer has half as many filters as the second one; the first
+ conv layer uses a 1x1 kernel and the second one uses a 3x3 kernel.
+
+ Args:
+ in_channels (int): The input channels. Must be even.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ """
+
+ def __init__(self,
+ in_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
+ super(ResBlock, self).__init__()
+ assert in_channels % 2 == 0 # ensure the in_channels is even
+ half_in_channels = in_channels // 2
+
+ # shortcut
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ self.conv1 = ConvModule(in_channels, half_in_channels, 1, **cfg)
+ self.conv2 = ConvModule(
+ half_in_channels, in_channels, 3, padding=1, **cfg)
+
+ def forward(self, x):
+ residual = x
+ out = self.conv1(x)
+ out = self.conv2(out)
+ out = out + residual
+
+ return out
+
+
+@BACKBONES.register_module()
+class Darknet(nn.Module):
+ """Darknet backbone.
+
+ Args:
+ depth (int): Depth of Darknet. Currently only supports 53.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters. Default: -1.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+
+ Example:
+ >>> from mmdet.models import Darknet
+ >>> import torch
+ >>> self = Darknet(depth=53)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 416, 416)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ ...
+ (1, 256, 52, 52)
+ (1, 512, 26, 26)
+ (1, 1024, 13, 13)
+ """
+
+ # Dict(depth: (layers, channels))
+ arch_settings = {
+ 53: ((1, 2, 8, 8, 4), ((32, 64), (64, 128), (128, 256), (256, 512),
+ (512, 1024)))
+ }
+
+ def __init__(self,
+ depth=53,
+ out_indices=(3, 4, 5),
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ norm_eval=True):
+ super(Darknet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for darknet')
+ self.depth = depth
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+ self.layers, self.channels = self.arch_settings[depth]
+
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ self.conv1 = ConvModule(3, 32, 3, padding=1, **cfg)
+
+ self.cr_blocks = ['conv1']
+ for i, n_layers in enumerate(self.layers):
+ layer_name = f'conv_res_block{i + 1}'
+ in_c, out_c = self.channels[i]
+ self.add_module(
+ layer_name,
+ self.make_conv_res_block(in_c, out_c, n_layers, **cfg))
+ self.cr_blocks.append(layer_name)
+
+ self.norm_eval = norm_eval
+
+ def forward(self, x):
+ outs = []
+ for i, layer_name in enumerate(self.cr_blocks):
+ cr_block = getattr(self, layer_name)
+ x = cr_block(x)
+ if i in self.out_indices:
+ outs.append(x)
+
+ return tuple(outs)
+
+ def init_weights(self, pretrained=None):
+ if isinstance(pretrained, str):
+ logger = logging.getLogger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ for i in range(self.frozen_stages):
+ m = getattr(self, self.cr_blocks[i])
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def train(self, mode=True):
+ super(Darknet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+ @staticmethod
+ def make_conv_res_block(in_channels,
+ out_channels,
+ res_repeat,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU',
+ negative_slope=0.1)):
+ """In Darknet backbone, ConvLayer is usually followed by ResBlock. This
+ function will make that. The Conv layers always have 3x3 filters with
+ stride=2. The number of the filters in Conv layer is the same as the
+ out channels of the ResBlock.
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ res_repeat (int): The number of ResBlocks.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ """
+
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ model = nn.Sequential()
+ model.add_module(
+ 'conv',
+ ConvModule(
+ in_channels, out_channels, 3, stride=2, padding=1, **cfg))
+ for idx in range(res_repeat):
+ model.add_module('res{}'.format(idx),
+ ResBlock(out_channels, **cfg))
+ return model
diff --git a/mmdet/models/backbones/detectors_resnet.py b/mmdet/models/backbones/detectors_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..519db464493c7c7b60fc34be1d21add2235ec341
--- /dev/null
+++ b/mmdet/models/backbones/detectors_resnet.py
@@ -0,0 +1,305 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer, constant_init
+
+from ..builder import BACKBONES
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+ r"""Bottleneck for the ResNet backbone in `DetectoRS
+ <https://arxiv.org/abs/2006.02334>`_.
+
+ This bottleneck allows the users to specify whether to use
+ SAC (Switchable Atrous Convolution) and RFP (Recursive Feature Pyramid).
+
+ Args:
+ inplanes (int): The number of input channels.
+ planes (int): The number of output channels before expansion.
+ rfp_inplanes (int, optional): The number of channels from RFP.
+ Default: None. If specified, an additional conv layer will be
+ added for ``rfp_feat``. Otherwise, the structure is the same as
+ base class.
+ sac (dict, optional): Dictionary to construct SAC. Default: None.
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ rfp_inplanes=None,
+ sac=None,
+ **kwargs):
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ assert sac is None or isinstance(sac, dict)
+ self.sac = sac
+ self.with_sac = sac is not None
+ if self.with_sac:
+ self.conv2 = build_conv_layer(
+ self.sac,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False)
+
+ self.rfp_inplanes = rfp_inplanes
+ if self.rfp_inplanes:
+ self.rfp_conv = build_conv_layer(
+ None,
+ self.rfp_inplanes,
+ planes * self.expansion,
+ 1,
+ stride=1,
+ bias=True)
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize the weights."""
+ if self.rfp_inplanes:
+ constant_init(self.rfp_conv, 0)
+
+ def rfp_forward(self, x, rfp_feat):
+ """The forward function that also takes the RFP features as input."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ if self.rfp_inplanes:
+ rfp_feat = self.rfp_conv(rfp_feat)
+ out = out + rfp_feat
+
+ out = self.relu(out)
+
+ return out
+
+
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone for RPF in detectoRS.
+
+ The difference between this module and base class is that we pass
+ ``rfp_inplanes`` to the first block.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ rfp_inplanes (int, optional): The number of channels from RFP.
+ Default: None. If specified, an additional conv layer will be
+ added for ``rfp_feat``. Otherwise, the structure is the same as
+ base class.
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ rfp_inplanes=None,
+ **kwargs):
+ self.block = block
+ assert downsample_first, f'downsample_first={downsample_first} is ' \
+ 'not supported in DetectoRS'
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down and stride != 1:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ rfp_inplanes=rfp_inplanes,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+
+ super(ResLayer, self).__init__(*layers)
+
+
+@BACKBONES.register_module()
+class DetectoRS_ResNet(ResNet):
+ """ResNet backbone for DetectoRS.
+
+ Args:
+ sac (dict, optional): Dictionary to construct SAC (Switchable Atrous
+ Convolution). Default: None.
+ stage_with_sac (list): Which stage to use sac. Default: (False, False,
+ False, False).
+ rfp_inplanes (int, optional): The number of channels from RFP.
+ Default: None. If specified, an additional conv layer will be
+ added for ``rfp_feat``. Otherwise, the structure is the same as
+ base class.
+ output_img (bool): If ``True``, the input image will be inserted into
+ the starting position of output. Default: False.
+ pretrained (str, optional): The pretrained model to load.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ sac=None,
+ stage_with_sac=(False, False, False, False),
+ rfp_inplanes=None,
+ output_img=False,
+ pretrained=None,
+ **kwargs):
+ self.sac = sac
+ self.stage_with_sac = stage_with_sac
+ self.rfp_inplanes = rfp_inplanes
+ self.output_img = output_img
+ self.pretrained = pretrained
+ super(DetectoRS_ResNet, self).__init__(**kwargs)
+
+ self.inplanes = self.stem_channels
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ sac = self.sac if self.stage_with_sac[i] else None
+ if self.plugins is not None:
+ stage_plugins = self.make_stage_plugins(self.plugins, i)
+ else:
+ stage_plugins = None
+ planes = self.base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn,
+ sac=sac,
+ rfp_inplanes=rfp_inplanes if i > 0 else None,
+ plugins=stage_plugins)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer`` for DetectoRS."""
+ return ResLayer(**kwargs)
+
+ def forward(self, x):
+ """Forward function."""
+ outs = list(super(DetectoRS_ResNet, self).forward(x))
+ if self.output_img:
+ outs.insert(0, x)
+ return tuple(outs)
+
+ def rfp_forward(self, x, rfp_feats):
+ """Forward function for RFP."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ rfp_feat = rfp_feats[i] if i > 0 else None
+ for layer in res_layer:
+ x = layer.rfp_forward(x, rfp_feat)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
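+
+# Illustrative usage (an assumption, not part of the original file): a
+# DetectoRS-style config typically enables SAC on the later stages and
+# passes the RFP channel count, e.g.
+#
+#   backbone=dict(
+#       type='DetectoRS_ResNet',
+#       depth=50,
+#       sac=dict(type='SAC', use_deform=True),
+#       stage_with_sac=(False, True, True, True),
+#       rfp_inplanes=256,
+#       output_img=True)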
diff --git a/mmdet/models/backbones/detectors_resnext.py b/mmdet/models/backbones/detectors_resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..57d032fe37ed82d5ba24e761bdc014cc0ee5ac64
--- /dev/null
+++ b/mmdet/models/backbones/detectors_resnext.py
@@ -0,0 +1,122 @@
+import math
+
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from .detectors_resnet import Bottleneck as _Bottleneck
+from .detectors_resnet import DetectoRS_ResNet
+
+
+class Bottleneck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ **kwargs):
+ """Bottleneck block for ResNeXt.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer;
+ if it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, width, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_sac:
+ self.conv2 = build_conv_layer(
+ self.sac,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ elif not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+
+@BACKBONES.register_module()
+class DetectoRS_ResNeXt(DetectoRS_ResNet):
+ """ResNeXt backbone for DetectoRS.
+
+ Args:
+ groups (int): The number of groups in ResNeXt.
+ base_width (int): The base width of ResNeXt.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, groups=1, base_width=4, **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ super(DetectoRS_ResNeXt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return super().make_res_layer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
diff --git a/mmdet/models/backbones/hourglass.py b/mmdet/models/backbones/hourglass.py
new file mode 100644
index 0000000000000000000000000000000000000000..3422acee35e3c6f8731cdb310f188e671b5be12f
--- /dev/null
+++ b/mmdet/models/backbones/hourglass.py
@@ -0,0 +1,198 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import BasicBlock
+
+
+class HourglassModule(nn.Module):
+ """Hourglass Module for HourglassNet backbone.
+
+ Generate module recursively and use BasicBlock as the base unit.
+
+ Args:
+ depth (int): Depth of current HourglassModule.
+ stage_channels (list[int]): Feature channels of sub-modules in current
+ and follow-up HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in current and
+ follow-up HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ """
+
+ def __init__(self,
+ depth,
+ stage_channels,
+ stage_blocks,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(HourglassModule, self).__init__()
+
+ self.depth = depth
+
+ cur_block = stage_blocks[0]
+ next_block = stage_blocks[1]
+
+ cur_channel = stage_channels[0]
+ next_channel = stage_channels[1]
+
+ self.up1 = ResLayer(
+ BasicBlock, cur_channel, cur_channel, cur_block, norm_cfg=norm_cfg)
+
+ self.low1 = ResLayer(
+ BasicBlock,
+ cur_channel,
+ next_channel,
+ cur_block,
+ stride=2,
+ norm_cfg=norm_cfg)
+
+ if self.depth > 1:
+ self.low2 = HourglassModule(depth - 1, stage_channels[1:],
+ stage_blocks[1:])
+ else:
+ self.low2 = ResLayer(
+ BasicBlock,
+ next_channel,
+ next_channel,
+ next_block,
+ norm_cfg=norm_cfg)
+
+ self.low3 = ResLayer(
+ BasicBlock,
+ next_channel,
+ cur_channel,
+ cur_block,
+ norm_cfg=norm_cfg,
+ downsample_first=False)
+
+ self.up2 = nn.Upsample(scale_factor=2)
+
+ def forward(self, x):
+ """Forward function."""
+ up1 = self.up1(x)
+ low1 = self.low1(x)
+ low2 = self.low2(low1)
+ low3 = self.low3(low2)
+ up2 = self.up2(low3)
+ return up1 + up2
+
+
+@BACKBONES.register_module()
+class HourglassNet(nn.Module):
+ """HourglassNet backbone.
+
+ Stacked Hourglass Networks for Human Pose Estimation.
+ More details can be found in the `paper
+ <https://arxiv.org/abs/1603.06937>`_.
+
+ Args:
+ downsample_times (int): Downsample times in a HourglassModule.
+ num_stacks (int): Number of HourglassModule modules stacked,
+ 1 for Hourglass-52, 2 for Hourglass-104.
+ stage_channels (list[int]): Feature channel of each sub-module in a
+ HourglassModule.
+ stage_blocks (list[int]): Number of sub-modules stacked in a
+ HourglassModule.
+ feat_channel (int): Feature channel of conv after a HourglassModule.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+
+ Example:
+ >>> from mmdet.models import HourglassNet
+ >>> import torch
+ >>> self = HourglassNet()
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 511, 511)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_output in level_outputs:
+ ... print(tuple(level_output.shape))
+ (1, 256, 128, 128)
+ (1, 256, 128, 128)
+ """
+
+ def __init__(self,
+ downsample_times=5,
+ num_stacks=2,
+ stage_channels=(256, 256, 384, 384, 384, 512),
+ stage_blocks=(2, 2, 2, 2, 2, 4),
+ feat_channel=256,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(HourglassNet, self).__init__()
+
+ self.num_stacks = num_stacks
+ assert self.num_stacks >= 1
+ assert len(stage_channels) == len(stage_blocks)
+ assert len(stage_channels) > downsample_times
+
+ cur_channel = stage_channels[0]
+
+ self.stem = nn.Sequential(
+ ConvModule(3, 128, 7, padding=3, stride=2, norm_cfg=norm_cfg),
+ ResLayer(BasicBlock, 128, 256, 1, stride=2, norm_cfg=norm_cfg))
+
+ self.hourglass_modules = nn.ModuleList([
+ HourglassModule(downsample_times, stage_channels, stage_blocks)
+ for _ in range(num_stacks)
+ ])
+
+ self.inters = ResLayer(
+ BasicBlock,
+ cur_channel,
+ cur_channel,
+ num_stacks - 1,
+ norm_cfg=norm_cfg)
+
+ self.conv1x1s = nn.ModuleList([
+ ConvModule(
+ cur_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
+ for _ in range(num_stacks - 1)
+ ])
+
+ self.out_convs = nn.ModuleList([
+ ConvModule(
+ cur_channel, feat_channel, 3, padding=1, norm_cfg=norm_cfg)
+ for _ in range(num_stacks)
+ ])
+
+ self.remap_convs = nn.ModuleList([
+ ConvModule(
+ feat_channel, cur_channel, 1, norm_cfg=norm_cfg, act_cfg=None)
+ for _ in range(num_stacks - 1)
+ ])
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def init_weights(self, pretrained=None):
+ """Init module weights.
+
+ We do nothing in this function because all the modules used here
+ (ConvModule, BasicBlock, etc.) have default initialization, and
+ currently we do not provide a pretrained model for HourglassNet.
+
+ Detector's __init__() will call backbone's init_weights() with
+ pretrained as input, so we keep this function.
+ """
+ # Training Centripetal Model needs to reset parameters for Conv2d
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ m.reset_parameters()
+
+ def forward(self, x):
+ """Forward function."""
+ inter_feat = self.stem(x)
+ out_feats = []
+
+ for ind in range(self.num_stacks):
+ single_hourglass = self.hourglass_modules[ind]
+ out_conv = self.out_convs[ind]
+
+ hourglass_feat = single_hourglass(inter_feat)
+ out_feat = out_conv(hourglass_feat)
+ out_feats.append(out_feat)
+
+ if ind < self.num_stacks - 1:
+ inter_feat = self.conv1x1s[ind](
+ inter_feat) + self.remap_convs[ind](
+ out_feat)
+ inter_feat = self.inters[ind](self.relu(inter_feat))
+
+ return out_feats
diff --git a/mmdet/models/backbones/hrnet.py b/mmdet/models/backbones/hrnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0fd0a974192231506aa68b1e1719f618b78a1b3
--- /dev/null
+++ b/mmdet/models/backbones/hrnet.py
@@ -0,0 +1,537 @@
+import torch.nn as nn
+from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import BasicBlock, Bottleneck
+
+
+class HRModule(nn.Module):
+ """High-Resolution Module for HRNet.
+
+ In this module, every branch has 4 BasicBlocks/Bottlenecks, and the
+ fusion/exchange across branches is also performed in this module.
+ """
+
+ def __init__(self,
+ num_branches,
+ blocks,
+ num_blocks,
+ in_channels,
+ num_channels,
+ multiscale_output=True,
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN')):
+ super(HRModule, self).__init__()
+ self._check_branches(num_branches, num_blocks, in_channels,
+ num_channels)
+
+ self.in_channels = in_channels
+ self.num_branches = num_branches
+
+ self.multiscale_output = multiscale_output
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+ self.with_cp = with_cp
+ self.branches = self._make_branches(num_branches, blocks, num_blocks,
+ num_channels)
+ self.fuse_layers = self._make_fuse_layers()
+ self.relu = nn.ReLU(inplace=False)
+
+ def _check_branches(self, num_branches, num_blocks, in_channels,
+ num_channels):
+ if num_branches != len(num_blocks):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_BLOCKS({len(num_blocks)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(num_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_CHANNELS({len(num_channels)})'
+ raise ValueError(error_msg)
+
+ if num_branches != len(in_channels):
+ error_msg = f'NUM_BRANCHES({num_branches}) ' \
+ f'!= NUM_INCHANNELS({len(in_channels)})'
+ raise ValueError(error_msg)
+
+ def _make_one_branch(self,
+ branch_index,
+ block,
+ num_blocks,
+ num_channels,
+ stride=1):
+ downsample = None
+ if stride != 1 or \
+ self.in_channels[branch_index] != \
+ num_channels[branch_index] * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ self.in_channels[branch_index],
+ num_channels[branch_index] * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, num_channels[branch_index] *
+ block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ self.in_channels[branch_index] = \
+ num_channels[branch_index] * block.expansion
+ for i in range(1, num_blocks[branch_index]):
+ layers.append(
+ block(
+ self.in_channels[branch_index],
+ num_channels[branch_index],
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_branches(self, num_branches, block, num_blocks, num_channels):
+ branches = []
+
+ for i in range(num_branches):
+ branches.append(
+ self._make_one_branch(i, block, num_blocks, num_channels))
+
+ return nn.ModuleList(branches)
+
+ def _make_fuse_layers(self):
+ if self.num_branches == 1:
+ return None
+
+ num_branches = self.num_branches
+ in_channels = self.in_channels
+ fuse_layers = []
+ num_out_branches = num_branches if self.multiscale_output else 1
+ for i in range(num_out_branches):
+ fuse_layer = []
+ for j in range(num_branches):
+ if j > i:
+ fuse_layer.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=False),
+ build_norm_layer(self.norm_cfg, in_channels[i])[1],
+ nn.Upsample(
+ scale_factor=2**(j - i), mode='nearest')))
+ elif j == i:
+ fuse_layer.append(None)
+ else:
+ conv_downsamples = []
+ for k in range(i - j):
+ if k == i - j - 1:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[i],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[i])[1]))
+ else:
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels[j],
+ in_channels[j],
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ in_channels[j])[1],
+ nn.ReLU(inplace=False)))
+ fuse_layer.append(nn.Sequential(*conv_downsamples))
+ fuse_layers.append(nn.ModuleList(fuse_layer))
+
+ return nn.ModuleList(fuse_layers)
+
+ def forward(self, x):
+ """Forward function."""
+ if self.num_branches == 1:
+ return [self.branches[0](x[0])]
+
+ for i in range(self.num_branches):
+ x[i] = self.branches[i](x[i])
+
+ x_fuse = []
+ for i in range(len(self.fuse_layers)):
+ y = 0
+ for j in range(self.num_branches):
+ if i == j:
+ y += x[j]
+ else:
+ y += self.fuse_layers[i][j](x[j])
+ x_fuse.append(self.relu(y))
+ return x_fuse
+
+
+@BACKBONES.register_module()
+class HRNet(nn.Module):
+ """HRNet backbone.
+
+ High-Resolution Representations for Labeling Pixels and Regions
+ arXiv: https://arxiv.org/abs/1904.04514
+
+ Args:
+ extra (dict): detailed configuration for each stage of HRNet.
+ in_channels (int): Number of input image channels. Default: 3.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from mmdet.models import HRNet
+ >>> import torch
+ >>> extra = dict(
+ >>> stage1=dict(
+ >>> num_modules=1,
+ >>> num_branches=1,
+ >>> block='BOTTLENECK',
+ >>> num_blocks=(4, ),
+ >>> num_channels=(64, )),
+ >>> stage2=dict(
+ >>> num_modules=1,
+ >>> num_branches=2,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4),
+ >>> num_channels=(32, 64)),
+ >>> stage3=dict(
+ >>> num_modules=4,
+ >>> num_branches=3,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4),
+ >>> num_channels=(32, 64, 128)),
+ >>> stage4=dict(
+ >>> num_modules=3,
+ >>> num_branches=4,
+ >>> block='BASIC',
+ >>> num_blocks=(4, 4, 4, 4),
+ >>> num_channels=(32, 64, 128, 256)))
+ >>> self = HRNet(extra, in_channels=1)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 1, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 32, 8, 8)
+ (1, 64, 4, 4)
+ (1, 128, 2, 2)
+ (1, 256, 1, 1)
+ """
+
+ blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
+
+ def __init__(self,
+ extra,
+ in_channels=3,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ norm_eval=True,
+ with_cp=False,
+ zero_init_residual=False):
+ super(HRNet, self).__init__()
+ self.extra = extra
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.norm_eval = norm_eval
+ self.with_cp = with_cp
+ self.zero_init_residual = zero_init_residual
+
+ # stem net
+ self.norm1_name, norm1 = build_norm_layer(self.norm_cfg, 64, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(self.norm_cfg, 64, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ 64,
+ 64,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.relu = nn.ReLU(inplace=True)
+
+ # stage 1
+ self.stage1_cfg = self.extra['stage1']
+ num_channels = self.stage1_cfg['num_channels'][0]
+ block_type = self.stage1_cfg['block']
+ num_blocks = self.stage1_cfg['num_blocks'][0]
+
+ block = self.blocks_dict[block_type]
+ stage1_out_channels = num_channels * block.expansion
+ self.layer1 = self._make_layer(block, 64, num_channels, num_blocks)
+
+ # stage 2
+ self.stage2_cfg = self.extra['stage2']
+ num_channels = self.stage2_cfg['num_channels']
+ block_type = self.stage2_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition1 = self._make_transition_layer([stage1_out_channels],
+ num_channels)
+ self.stage2, pre_stage_channels = self._make_stage(
+ self.stage2_cfg, num_channels)
+
+ # stage 3
+ self.stage3_cfg = self.extra['stage3']
+ num_channels = self.stage3_cfg['num_channels']
+ block_type = self.stage3_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition2 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage3, pre_stage_channels = self._make_stage(
+ self.stage3_cfg, num_channels)
+
+ # stage 4
+ self.stage4_cfg = self.extra['stage4']
+ num_channels = self.stage4_cfg['num_channels']
+ block_type = self.stage4_cfg['block']
+
+ block = self.blocks_dict[block_type]
+ num_channels = [channel * block.expansion for channel in num_channels]
+ self.transition3 = self._make_transition_layer(pre_stage_channels,
+ num_channels)
+ self.stage4, pre_stage_channels = self._make_stage(
+ self.stage4_cfg, num_channels)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: the normalization layer named "norm2" """
+ return getattr(self, self.norm2_name)
+
+ def _make_transition_layer(self, num_channels_pre_layer,
+ num_channels_cur_layer):
+ num_branches_cur = len(num_channels_cur_layer)
+ num_branches_pre = len(num_channels_pre_layer)
+
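+        # Existing branches are reused (or adapted by a 3x3 conv when the
+        # channel count changes); each extra branch of the new stage is built
+        # from the last previous branch via stride-2 3x3 convs.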
+ transition_layers = []
+ for i in range(num_branches_cur):
+ if i < num_branches_pre:
+ if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
+ transition_layers.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ num_channels_pre_layer[i],
+ num_channels_cur_layer[i],
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg,
+ num_channels_cur_layer[i])[1],
+ nn.ReLU(inplace=True)))
+ else:
+ transition_layers.append(None)
+ else:
+ conv_downsamples = []
+ for j in range(i + 1 - num_branches_pre):
+ in_channels = num_channels_pre_layer[-1]
+ out_channels = num_channels_cur_layer[i] \
+ if j == i - num_branches_pre else in_channels
+ conv_downsamples.append(
+ nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, out_channels)[1],
+ nn.ReLU(inplace=True)))
+ transition_layers.append(nn.Sequential(*conv_downsamples))
+
+ return nn.ModuleList(transition_layers)
+
+ def _make_layer(self, block, inplanes, planes, blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = []
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ stride,
+ downsample=downsample,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+ inplanes = planes * block.expansion
+ for i in range(1, blocks):
+ layers.append(
+ block(
+ inplanes,
+ planes,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*layers)
+
+ def _make_stage(self, layer_config, in_channels, multiscale_output=True):
+ num_modules = layer_config['num_modules']
+ num_branches = layer_config['num_branches']
+ num_blocks = layer_config['num_blocks']
+ num_channels = layer_config['num_channels']
+ block = self.blocks_dict[layer_config['block']]
+
+ hr_modules = []
+ for i in range(num_modules):
+            # multiscale_output is only used for the last module
+ if not multiscale_output and i == num_modules - 1:
+ reset_multiscale_output = False
+ else:
+ reset_multiscale_output = True
+
+ hr_modules.append(
+ HRModule(
+ num_branches,
+ block,
+ num_blocks,
+ in_channels,
+ num_channels,
+ reset_multiscale_output,
+ with_cp=self.with_cp,
+ norm_cfg=self.norm_cfg,
+ conv_cfg=self.conv_cfg))
+
+ return nn.Sequential(*hr_modules), in_channels
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.conv2(x)
+ x = self.norm2(x)
+ x = self.relu(x)
+ x = self.layer1(x)
+
+ x_list = []
+ for i in range(self.stage2_cfg['num_branches']):
+ if self.transition1[i] is not None:
+ x_list.append(self.transition1[i](x))
+ else:
+ x_list.append(x)
+ y_list = self.stage2(x_list)
+
+ x_list = []
+ for i in range(self.stage3_cfg['num_branches']):
+ if self.transition2[i] is not None:
+ x_list.append(self.transition2[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage3(x_list)
+
+ x_list = []
+ for i in range(self.stage4_cfg['num_branches']):
+ if self.transition3[i] is not None:
+ x_list.append(self.transition3[i](y_list[-1]))
+ else:
+ x_list.append(y_list[i])
+ y_list = self.stage4(x_list)
+
+ return y_list
+
+ def train(self, mode=True):
+ """Convert the model into training mode will keeping the normalization
+ layer freezed."""
+ super(HRNet, self).train(mode)
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval has an effect on BatchNorm layers only
+ if isinstance(m, _BatchNorm):
+ m.eval()
diff --git a/mmdet/models/backbones/regnet.py b/mmdet/models/backbones/regnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..91a602a952226cebb5fd0e3e282c6f98ae4fa455
--- /dev/null
+++ b/mmdet/models/backbones/regnet.py
@@ -0,0 +1,325 @@
+import numpy as np
+import torch.nn as nn
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from .resnet import ResNet
+from .resnext import Bottleneck
+
+
+@BACKBONES.register_module()
+class RegNet(ResNet):
+ """RegNet backbone.
+
+    More details can be found in the `paper <https://arxiv.org/abs/2003.13678>`_ .
+
+ Args:
+ arch (dict): The parameter of RegNets.
+
+ - w0 (int): initial width
+ - wa (float): slope of width
+ - wm (float): quantization parameter to quantize the width
+ - depth (int): depth of the backbone
+ - group_w (int): width of group
+ - bot_mul (float): bottleneck ratio, i.e. expansion of bottleneck.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ base_channels (int): Base channels after stem layer.
+ in_channels (int): Number of input image channels. Default: 3.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from mmdet.models import RegNet
+ >>> import torch
+ >>> self = RegNet(
+        ...     arch=dict(
+        ...         w0=88,
+        ...         wa=26.31,
+        ...         wm=2.25,
+        ...         group_w=48,
+        ...         depth=25,
+        ...         bot_mul=1.0))
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 96, 8, 8)
+ (1, 192, 4, 4)
+ (1, 432, 2, 2)
+ (1, 1008, 1, 1)
+ """
+ arch_settings = {
+ 'regnetx_400mf':
+ dict(w0=24, wa=24.48, wm=2.54, group_w=16, depth=22, bot_mul=1.0),
+ 'regnetx_800mf':
+ dict(w0=56, wa=35.73, wm=2.28, group_w=16, depth=16, bot_mul=1.0),
+ 'regnetx_1.6gf':
+ dict(w0=80, wa=34.01, wm=2.25, group_w=24, depth=18, bot_mul=1.0),
+ 'regnetx_3.2gf':
+ dict(w0=88, wa=26.31, wm=2.25, group_w=48, depth=25, bot_mul=1.0),
+ 'regnetx_4.0gf':
+ dict(w0=96, wa=38.65, wm=2.43, group_w=40, depth=23, bot_mul=1.0),
+ 'regnetx_6.4gf':
+ dict(w0=184, wa=60.83, wm=2.07, group_w=56, depth=17, bot_mul=1.0),
+ 'regnetx_8.0gf':
+ dict(w0=80, wa=49.56, wm=2.88, group_w=120, depth=23, bot_mul=1.0),
+ 'regnetx_12gf':
+ dict(w0=168, wa=73.36, wm=2.37, group_w=112, depth=19, bot_mul=1.0),
+ }
+
+ def __init__(self,
+ arch,
+ in_channels=3,
+ stem_channels=32,
+ base_channels=32,
+ strides=(2, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ with_cp=False,
+ zero_init_residual=True):
+ super(ResNet, self).__init__()
+
+ # Generate RegNet parameters first
+ if isinstance(arch, str):
+ assert arch in self.arch_settings, \
+ f'"arch": "{arch}" is not one of the' \
+ ' arch_settings'
+ arch = self.arch_settings[arch]
+ elif not isinstance(arch, dict):
+ raise ValueError('Expect "arch" to be either a string '
+ f'or a dict, got {type(arch)}')
+
+ widths, num_stages = self.generate_regnet(
+ arch['w0'],
+ arch['wa'],
+ arch['wm'],
+ arch['depth'],
+ )
+ # Convert to per stage format
+ stage_widths, stage_blocks = self.get_stages_from_blocks(widths)
+ # Generate group widths and bot muls
+ group_widths = [arch['group_w'] for _ in range(num_stages)]
+ self.bottleneck_ratio = [arch['bot_mul'] for _ in range(num_stages)]
+ # Adjust the compatibility of stage_widths and group_widths
+ stage_widths, group_widths = self.adjust_width_group(
+ stage_widths, self.bottleneck_ratio, group_widths)
+
+ # Group params by stage
+ self.stage_widths = stage_widths
+ self.group_widths = group_widths
+ self.depth = sum(stage_blocks)
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.zero_init_residual = zero_init_residual
+ self.block = Bottleneck
+ expansion_bak = self.block.expansion
+ self.block.expansion = 1
+ self.stage_blocks = stage_blocks[:num_stages]
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.inplanes = stem_channels
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = self.strides[i]
+ dilation = self.dilations[i]
+ group_width = self.group_widths[i]
+ width = int(round(self.stage_widths[i] * self.bottleneck_ratio[i]))
+ stage_groups = width // group_width
+
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if self.plugins is not None:
+ stage_plugins = self.make_stage_plugins(self.plugins, i)
+ else:
+ stage_plugins = None
+
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=self.stage_widths[i],
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ groups=stage_groups,
+ base_width=group_width,
+ base_channels=self.stage_widths[i])
+ self.inplanes = self.stage_widths[i]
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = stage_widths[-1]
+ self.block.expansion = expansion_bak
+
+ def _make_stem_layer(self, in_channels, base_channels):
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ base_channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, base_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def generate_regnet(self,
+ initial_width,
+ width_slope,
+ width_parameter,
+ depth,
+ divisor=8):
+ """Generates per block width from RegNet parameters.
+
+ Args:
+            initial_width (int): Initial width of the backbone.
+            width_slope (float): Slope of the quantized linear function.
+            width_parameter (float): Parameter used to quantize the width.
+            depth (int): Depth of the backbone.
+ divisor (int, optional): The divisor of channels. Defaults to 8.
+
+ Returns:
+            tuple(list, int): A list of per-block widths and the number of
+                stages.
+ """
+ assert width_slope >= 0
+ assert initial_width > 0
+ assert width_parameter > 1
+ assert initial_width % divisor == 0
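+        # Per-block widths follow the linear rule u_j = w0 + wa * j; each is
+        # snapped to the nearest power of wm times w0 and rounded to a
+        # multiple of `divisor`, so blocks sharing a width form one stage.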
+ widths_cont = np.arange(depth) * width_slope + initial_width
+ ks = np.round(
+ np.log(widths_cont / initial_width) / np.log(width_parameter))
+ widths = initial_width * np.power(width_parameter, ks)
+ widths = np.round(np.divide(widths, divisor)) * divisor
+ num_stages = len(np.unique(widths))
+ widths, widths_cont = widths.astype(int).tolist(), widths_cont.tolist()
+ return widths, num_stages
+
+ @staticmethod
+ def quantize_float(number, divisor):
+ """Converts a float to closest non-zero int divisible by divisor.
+
+ Args:
+ number (int): Original number to be quantized.
+ divisor (int): Divisor used to quantize the number.
+
+ Returns:
+            int: Quantized number that is divisible by divisor.
+ """
+ return int(round(number / divisor) * divisor)
+
+ def adjust_width_group(self, widths, bottleneck_ratio, groups):
+ """Adjusts the compatibility of widths and groups.
+
+ Args:
+ widths (list[int]): Width of each stage.
+            bottleneck_ratio (list[float]): Bottleneck ratio of each stage.
+            groups (list[int]): Group width of each stage.
+
+ Returns:
+ tuple(list): The adjusted widths and groups of each stage.
+ """
+ bottleneck_width = [
+ int(w * b) for w, b in zip(widths, bottleneck_ratio)
+ ]
+ groups = [min(g, w_bot) for g, w_bot in zip(groups, bottleneck_width)]
+ bottleneck_width = [
+ self.quantize_float(w_bot, g)
+ for w_bot, g in zip(bottleneck_width, groups)
+ ]
+ widths = [
+ int(w_bot / b)
+ for w_bot, b in zip(bottleneck_width, bottleneck_ratio)
+ ]
+ return widths, groups
+
+ def get_stages_from_blocks(self, widths):
+ """Gets widths/stage_blocks of network at each stage.
+
+ Args:
+ widths (list[int]): Width in each stage.
+
+ Returns:
+ tuple(list): width and depth of each stage
+ """
+ width_diff = [
+ width != width_prev
+ for width, width_prev in zip(widths + [0], [0] + widths)
+ ]
+ stage_widths = [
+ width for width, diff in zip(widths, width_diff[:-1]) if diff
+ ]
+ stage_blocks = np.diff([
+ depth for depth, diff in zip(range(len(width_diff)), width_diff)
+ if diff
+ ]).tolist()
+ return stage_widths, stage_blocks
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
diff --git a/mmdet/models/backbones/res2net.py b/mmdet/models/backbones/res2net.py
new file mode 100644
index 0000000000000000000000000000000000000000..7901b7f2fa29741d72328bdbdbf92fc4d5c5f847
--- /dev/null
+++ b/mmdet/models/backbones/res2net.py
@@ -0,0 +1,351 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
+ kaiming_init)
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottle2neck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ scales=4,
+ base_width=26,
+ base_channels=64,
+ stage_type='normal',
+ **kwargs):
+ """Bottle2neck block for Res2Net.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottle2neck, self).__init__(inplanes, planes, **kwargs)
+ assert scales > 1, 'Res2Net degenerates to ResNet when scales = 1.'
+ width = int(math.floor(self.planes * (base_width / base_channels)))
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width * scales, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width * scales,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+
+ if stage_type == 'stage' and self.conv2_stride != 1:
+ self.pool = nn.AvgPool2d(
+ kernel_size=3, stride=self.conv2_stride, padding=1)
+ convs = []
+ bns = []
+
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ for i in range(scales - 1):
+ convs.append(
+ build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False))
+ bns.append(
+ build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ for i in range(scales - 1):
+ convs.append(
+ build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ bias=False))
+ bns.append(
+ build_norm_layer(self.norm_cfg, width, postfix=i + 1)[1])
+ self.convs = nn.ModuleList(convs)
+ self.bns = nn.ModuleList(bns)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width * scales,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.stage_type = stage_type
+ self.scales = scales
+ self.width = width
+ delattr(self, 'conv2')
+ delattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
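+            # Res2Net: split the channels into `scales` groups and convolve
+            # them sequentially; in 'normal' blocks each group is added to the
+            # previous output first, building multi-scale receptive fields.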
+ spx = torch.split(out, self.width, 1)
+ sp = self.convs[0](spx[0].contiguous())
+ sp = self.relu(self.bns[0](sp))
+ out = sp
+ for i in range(1, self.scales - 1):
+ if self.stage_type == 'stage':
+ sp = spx[i]
+ else:
+ sp = sp + spx[i]
+ sp = self.convs[i](sp.contiguous())
+ sp = self.relu(self.bns[i](sp))
+ out = torch.cat((out, sp), 1)
+
+ if self.stage_type == 'normal' or self.conv2_stride == 1:
+ out = torch.cat((out, spx[self.scales - 1]), 1)
+ elif self.stage_type == 'stage':
+ out = torch.cat((out, self.pool(spx[self.scales - 1])), 1)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Res2Layer(nn.Sequential):
+ """Res2Layer to build Res2Net style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+        avg_down (bool): Use AvgPool instead of stride conv when
+            downsampling in the bottle2neck. Default: True
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ scales (int): Scales used in Res2Net. Default: 4
+ base_width (int): Basic width of each scale. Default: 26
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=True,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ scales=4,
+ base_width=26,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False),
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=1,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1],
+ )
+
+ layers = []
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ scales=scales,
+ base_width=base_width,
+ stage_type='stage',
+ **kwargs))
+ inplanes = planes * block.expansion
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ scales=scales,
+ base_width=base_width,
+ **kwargs))
+ super(Res2Layer, self).__init__(*layers)
+
+
+@BACKBONES.register_module()
+class Res2Net(ResNet):
+ """Res2Net backbone.
+
+ Args:
+ scales (int): Scales used in Res2Net. Default: 4
+ base_width (int): Basic width of each scale. Default: 26
+ depth (int): Depth of res2net, from {50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Res2net stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottle2neck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+ - position (str, required): Position inside block to insert
+ plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from mmdet.models import Res2Net
+ >>> import torch
+ >>> self = Res2Net(depth=50, scales=4, base_width=26)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 256, 8, 8)
+ (1, 512, 4, 4)
+ (1, 1024, 2, 2)
+ (1, 2048, 1, 1)
+ """
+
+ arch_settings = {
+ 50: (Bottle2neck, (3, 4, 6, 3)),
+ 101: (Bottle2neck, (3, 4, 23, 3)),
+ 152: (Bottle2neck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ scales=4,
+ base_width=26,
+ style='pytorch',
+ deep_stem=True,
+ avg_down=True,
+ **kwargs):
+ self.scales = scales
+ self.base_width = base_width
+ super(Res2Net, self).__init__(
+ style='pytorch', deep_stem=True, avg_down=True, **kwargs)
+
+ def make_res_layer(self, **kwargs):
+ return Res2Layer(
+ scales=self.scales,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.dcn is not None:
+ for m in self.modules():
+ if isinstance(m, Bottle2neck):
+ # dcn in Res2Net bottle2neck is in ModuleList
+ for n in m.convs:
+ if hasattr(n, 'conv_offset'):
+ constant_init(n.conv_offset, 0)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottle2neck):
+ constant_init(m.norm3, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
diff --git a/mmdet/models/backbones/resnest.py b/mmdet/models/backbones/resnest.py
new file mode 100644
index 0000000000000000000000000000000000000000..48e1d8bfa47348a13f0da0b9ecf32354fa270340
--- /dev/null
+++ b/mmdet/models/backbones/resnest.py
@@ -0,0 +1,317 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNetV1d
+
+
+class RSoftmax(nn.Module):
+ """Radix Softmax module in ``SplitAttentionConv2d``.
+
+ Args:
+ radix (int): Radix of input.
+ groups (int): Groups of input.
+ """
+
+ def __init__(self, radix, groups):
+ super().__init__()
+ self.radix = radix
+ self.groups = groups
+
+ def forward(self, x):
+ batch = x.size(0)
+ if self.radix > 1:
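+            # Reshape to (batch, radix, groups, -1) and softmax over the radix
+            # dimension so the split weights of each group sum to one.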
+ x = x.view(batch, self.groups, self.radix, -1).transpose(1, 2)
+ x = F.softmax(x, dim=1)
+ x = x.reshape(batch, -1)
+ else:
+ x = torch.sigmoid(x)
+ return x
+
+
+class SplitAttentionConv2d(nn.Module):
+ """Split-Attention Conv2d in ResNeSt.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ channels (int): Number of intermediate channels.
+ kernel_size (int | tuple[int]): Size of the convolution kernel.
+ stride (int | tuple[int]): Stride of the convolution.
+        padding (int | tuple[int]): Zero-padding added to both sides of
+            the input. Default: 0.
+        dilation (int | tuple[int]): Spacing between kernel elements.
+        groups (int): Number of blocked connections from input channels to
+            output channels. Same as nn.Conv2d.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of inter_channels. Default: 4.
+ conv_cfg (dict): Config dict for convolution layer. Default: None,
+ which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+ dcn (dict): Config dict for DCN. Default: None.
+ """
+
+ def __init__(self,
+ in_channels,
+ channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ radix=2,
+ reduction_factor=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None):
+ super(SplitAttentionConv2d, self).__init__()
+ inter_channels = max(in_channels * radix // reduction_factor, 32)
+ self.radix = radix
+ self.groups = groups
+ self.channels = channels
+ self.with_dcn = dcn is not None
+ self.dcn = dcn
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if self.with_dcn and not fallback_on_stride:
+ assert conv_cfg is None, 'conv_cfg must be None for DCN'
+ conv_cfg = dcn
+ self.conv = build_conv_layer(
+ conv_cfg,
+ in_channels,
+ channels * radix,
+ kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups * radix,
+ bias=False)
+ # To be consistent with original implementation, starting from 0
+ self.norm0_name, norm0 = build_norm_layer(
+ norm_cfg, channels * radix, postfix=0)
+ self.add_module(self.norm0_name, norm0)
+ self.relu = nn.ReLU(inplace=True)
+ self.fc1 = build_conv_layer(
+ None, channels, inter_channels, 1, groups=self.groups)
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, inter_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.fc2 = build_conv_layer(
+ None, inter_channels, channels * radix, 1, groups=self.groups)
+ self.rsoftmax = RSoftmax(radix, groups)
+
+ @property
+ def norm0(self):
+ """nn.Module: the normalization layer named "norm0" """
+ return getattr(self, self.norm0_name)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def forward(self, x):
+ x = self.conv(x)
+ x = self.norm0(x)
+ x = self.relu(x)
+
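+        # Sum the radix splits, global-average-pool, then two grouped 1x1
+        # convs (fc1 -> fc2) predict per-split attention logits; RSoftmax
+        # normalizes them before the splits are re-weighted and summed.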
+        batch = x.size(0)
+ if self.radix > 1:
+ splits = x.view(batch, self.radix, -1, *x.shape[2:])
+ gap = splits.sum(dim=1)
+ else:
+ gap = x
+ gap = F.adaptive_avg_pool2d(gap, 1)
+ gap = self.fc1(gap)
+
+ gap = self.norm1(gap)
+ gap = self.relu(gap)
+
+ atten = self.fc2(gap)
+ atten = self.rsoftmax(atten).view(batch, -1, 1, 1)
+
+ if self.radix > 1:
+ attens = atten.view(batch, self.radix, -1, *atten.shape[2:])
+ out = torch.sum(attens * splits, dim=1)
+ else:
+ out = atten * x
+ return out.contiguous()
+
+
+class Bottleneck(_Bottleneck):
+ """Bottleneck block for ResNeSt.
+
+ Args:
+        inplanes (int): Input planes of this block.
+ planes (int): Middle planes of this block.
+ groups (int): Groups of conv2.
+ base_width (int): Base of width in terms of base channels. Default: 4.
+ base_channels (int): Base of channels for calculating width.
+ Default: 64.
+        radix (int): Radix of SplitAttentionConv2d. Default: 2.
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Key word arguments for base class.
+ """
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ """Bottleneck block for ResNeSt."""
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
+
+ self.avg_down_stride = avg_down_stride and self.conv2_stride > 1
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.with_modulated_dcn = False
+ self.conv2 = SplitAttentionConv2d(
+ width,
+ width,
+ kernel_size=3,
+ stride=1 if self.avg_down_stride else self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ radix=radix,
+ reduction_factor=reduction_factor,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=self.dcn)
+ delattr(self, self.norm2_name)
+
+ if self.avg_down_stride:
+ self.avd_layer = nn.AvgPool2d(3, self.conv2_stride, padding=1)
+
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+
+ if self.avg_down_stride:
+ out = self.avd_layer(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNeSt(ResNetV1d):
+ """ResNeSt backbone.
+
+ Args:
+ groups (int): Number of groups of Bottleneck. Default: 1
+ base_width (int): Base width of Bottleneck. Default: 4
+ radix (int): Radix of SplitAttentionConv2d. Default: 2
+ reduction_factor (int): Reduction factor of inter_channels in
+ SplitAttentionConv2d. Default: 4.
+ avg_down_stride (bool): Whether to use average pool for stride in
+ Bottleneck. Default: True.
+ kwargs (dict): Keyword arguments for ResNet.
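+
+    Example:
+        >>> # illustrative usage; output shapes assume the default strides
+        >>> from mmdet.models import ResNeSt
+        >>> import torch
+        >>> self = ResNeSt(depth=50, radix=2, reduction_factor=4)
+        >>> self.eval()
+        >>> inputs = torch.rand(1, 3, 32, 32)
+        >>> level_outputs = self.forward(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 256, 8, 8)
+        (1, 512, 4, 4)
+        (1, 1024, 2, 2)
+        (1, 2048, 1, 1)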
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3)),
+ 200: (Bottleneck, (3, 24, 36, 3))
+ }
+
+ def __init__(self,
+ groups=1,
+ base_width=4,
+ radix=2,
+ reduction_factor=4,
+ avg_down_stride=True,
+ **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ self.radix = radix
+ self.reduction_factor = reduction_factor
+ self.avg_down_stride = avg_down_stride
+ super(ResNeSt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ radix=self.radix,
+ reduction_factor=self.reduction_factor,
+ avg_down_stride=self.avg_down_stride,
+ **kwargs)
diff --git a/mmdet/models/backbones/resnet.py b/mmdet/models/backbones/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..3826815a6d94fdc4c54001d4c186d10ca3380e80
--- /dev/null
+++ b/mmdet/models/backbones/resnet.py
@@ -0,0 +1,663 @@
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import (build_conv_layer, build_norm_layer, build_plugin_layer,
+ constant_init, kaiming_init)
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+from ..utils import ResLayer
+
+
+class BasicBlock(nn.Module):
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(BasicBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=False)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+class Bottleneck(nn.Module):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ """Bottleneck block for ResNet.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__()
+ assert style in ['pytorch', 'caffe']
+ assert dcn is None or isinstance(dcn, dict)
+ assert plugins is None or isinstance(plugins, list)
+ if plugins is not None:
+ allowed_position = ['after_conv1', 'after_conv2', 'after_conv3']
+ assert all(p['position'] in allowed_position for p in plugins)
+
+ self.inplanes = inplanes
+ self.planes = planes
+ self.stride = stride
+ self.dilation = dilation
+ self.style = style
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.dcn = dcn
+ self.with_dcn = dcn is not None
+ self.plugins = plugins
+ self.with_plugins = plugins is not None
+
+ if self.with_plugins:
+ # collect plugins for conv1/conv2/conv3
+ self.after_conv1_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv1'
+ ]
+ self.after_conv2_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv2'
+ ]
+ self.after_conv3_plugins = [
+ plugin['cfg'] for plugin in plugins
+ if plugin['position'] == 'after_conv3'
+ ]
+
+ if self.style == 'pytorch':
+ self.conv1_stride = 1
+ self.conv2_stride = stride
+ else:
+ self.conv1_stride = stride
+ self.conv2_stride = 1
+
+ self.norm1_name, norm1 = build_norm_layer(norm_cfg, planes, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(norm_cfg, planes, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ norm_cfg, planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ if self.with_dcn:
+ fallback_on_stride = dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ dcn,
+ planes,
+ planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ conv_cfg,
+ planes,
+ planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+
+ if self.with_plugins:
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ planes, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ planes, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ planes * self.expansion, self.after_conv3_plugins)
+
+ def make_block_plugins(self, in_channels, plugins):
+ """make plugins for block.
+
+ Args:
+ in_channels (int): Input channels of plugin.
+ plugins (list[dict]): List of plugins cfg to build.
+
+ Returns:
+ list[str]: List of the names of plugin.
+ """
+ assert isinstance(plugins, list)
+ plugin_names = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ name, layer = build_plugin_layer(
+ plugin,
+ in_channels=in_channels,
+ postfix=plugin.pop('postfix', ''))
+ assert not hasattr(self, name), f'duplicate plugin {name}'
+ self.add_module(name, layer)
+ plugin_names.append(name)
+ return plugin_names
+
+ def forward_plugin(self, x, plugin_names):
+ out = x
+ for name in plugin_names:
+            # chain the plugins so each one operates on the previous output
+            out = getattr(self, name)(out)
+ return out
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name)
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name)
+
+ @property
+ def norm3(self):
+ """nn.Module: normalization layer after the third convolution layer"""
+ return getattr(self, self.norm3_name)
+
+ def forward(self, x):
+ """Forward function."""
+
+ def _inner_forward(x):
+ identity = x
+ out = self.conv1(x)
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = self.norm2(out)
+ out = self.relu(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv2_plugin_names)
+
+ out = self.conv3(out)
+ out = self.norm3(out)
+
+ if self.with_plugins:
+ out = self.forward_plugin(out, self.after_conv3_plugin_names)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = self.relu(out)
+
+ return out
+
+
+@BACKBONES.register_module()
+class ResNet(nn.Module):
+ """ResNet backbone.
+
+ Args:
+ depth (int): Depth of resnet, from {18, 34, 50, 101, 152}.
+ stem_channels (int | None): Number of stem channels. If not specified,
+ it will be the same as `base_channels`. Default: None.
+ base_channels (int): Number of base channels of res layer. Default: 64.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Resnet stages. Default: 4.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ deep_stem (bool): Replace 7x7 conv in input stem with 3 3x3 conv
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ plugins (list[dict]): List of plugins for stages, each dict contains:
+
+ - cfg (dict, required): Cfg dict to build plugin.
+ - position (str, required): Position inside block to insert
+ plugin, options are 'after_conv1', 'after_conv2', 'after_conv3'.
+ - stages (tuple[bool], optional): Stages to apply plugin, length
+ should be same as 'num_stages'.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for last norm layer
+ in resblocks to let them behave as identity.
+
+ Example:
+ >>> from mmdet.models import ResNet
+ >>> import torch
+ >>> self = ResNet(depth=18)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 32, 32)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 64, 8, 8)
+ (1, 128, 4, 4)
+ (1, 256, 2, 2)
+ (1, 512, 1, 1)
+ """
+
+ arch_settings = {
+ 18: (BasicBlock, (2, 2, 2, 2)),
+ 34: (BasicBlock, (3, 4, 6, 3)),
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self,
+ depth,
+ in_channels=3,
+ stem_channels=None,
+ base_channels=64,
+ num_stages=4,
+ strides=(1, 2, 2, 2),
+ dilations=(1, 1, 1, 1),
+ out_indices=(0, 1, 2, 3),
+ style='pytorch',
+ deep_stem=False,
+ avg_down=False,
+ frozen_stages=-1,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ dcn=None,
+ stage_with_dcn=(False, False, False, False),
+ plugins=None,
+ with_cp=False,
+ zero_init_residual=True):
+ super(ResNet, self).__init__()
+ if depth not in self.arch_settings:
+ raise KeyError(f'invalid depth {depth} for resnet')
+ self.depth = depth
+ if stem_channels is None:
+ stem_channels = base_channels
+ self.stem_channels = stem_channels
+ self.base_channels = base_channels
+ self.num_stages = num_stages
+ assert num_stages >= 1 and num_stages <= 4
+ self.strides = strides
+ self.dilations = dilations
+ assert len(strides) == len(dilations) == num_stages
+ self.out_indices = out_indices
+ assert max(out_indices) < num_stages
+ self.style = style
+ self.deep_stem = deep_stem
+ self.avg_down = avg_down
+ self.frozen_stages = frozen_stages
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.with_cp = with_cp
+ self.norm_eval = norm_eval
+ self.dcn = dcn
+ self.stage_with_dcn = stage_with_dcn
+ if dcn is not None:
+ assert len(stage_with_dcn) == num_stages
+ self.plugins = plugins
+ self.zero_init_residual = zero_init_residual
+ self.block, stage_blocks = self.arch_settings[depth]
+ self.stage_blocks = stage_blocks[:num_stages]
+ self.inplanes = stem_channels
+
+ self._make_stem_layer(in_channels, stem_channels)
+
+ self.res_layers = []
+ for i, num_blocks in enumerate(self.stage_blocks):
+ stride = strides[i]
+ dilation = dilations[i]
+ dcn = self.dcn if self.stage_with_dcn[i] else None
+ if plugins is not None:
+ stage_plugins = self.make_stage_plugins(plugins, i)
+ else:
+ stage_plugins = None
+ planes = base_channels * 2**i
+ res_layer = self.make_res_layer(
+ block=self.block,
+ inplanes=self.inplanes,
+ planes=planes,
+ num_blocks=num_blocks,
+ stride=stride,
+ dilation=dilation,
+ style=self.style,
+ avg_down=self.avg_down,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins)
+ self.inplanes = planes * self.block.expansion
+ layer_name = f'layer{i + 1}'
+ self.add_module(layer_name, res_layer)
+ self.res_layers.append(layer_name)
+
+ self._freeze_stages()
+
+ self.feat_dim = self.block.expansion * base_channels * 2**(
+ len(self.stage_blocks) - 1)
+
+ def make_stage_plugins(self, plugins, stage_idx):
+ """Make plugins for ResNet ``stage_idx`` th stage.
+
+ Currently we support to insert ``context_block``,
+ ``empirical_attention_block``, ``nonlocal_block`` into the backbone
+ like ResNet/ResNeXt. They could be inserted after conv1/conv2/conv3 of
+ Bottleneck.
+
+ An example of plugins format could be:
+
+ Examples:
+ >>> plugins=[
+ ... dict(cfg=dict(type='xxx', arg1='xxx'),
+ ... stages=(False, True, True, True),
+ ... position='after_conv2'),
+ ... dict(cfg=dict(type='yyy'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='1'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3'),
+ ... dict(cfg=dict(type='zzz', postfix='2'),
+ ... stages=(True, True, True, True),
+ ... position='after_conv3')
+ ... ]
+ >>> self = ResNet(depth=18)
+ >>> stage_plugins = self.make_stage_plugins(plugins, 0)
+ >>> assert len(stage_plugins) == 3
+
+ Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
+
+ .. code-block:: none
+
+ conv1-> conv2->conv3->yyy->zzz1->zzz2
+
+        Suppose ``stage_idx=1``, the structure of blocks in the stage would be:
+
+ .. code-block:: none
+
+ conv1-> conv2->xxx->conv3->yyy->zzz1->zzz2
+
+ If stages is missing, the plugin would be applied to all stages.
+
+ Args:
+ plugins (list[dict]): List of plugins cfg to build. The postfix is
+ required if multiple same type plugins are inserted.
+ stage_idx (int): Index of stage to build
+
+ Returns:
+ list[dict]: Plugins for current stage
+ """
+ stage_plugins = []
+ for plugin in plugins:
+ plugin = plugin.copy()
+ stages = plugin.pop('stages', None)
+ assert stages is None or len(stages) == self.num_stages
+ # whether to insert plugin into current stage
+ if stages is None or stages[stage_idx]:
+ stage_plugins.append(plugin)
+
+ return stage_plugins
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``."""
+ return ResLayer(**kwargs)
+
+ @property
+ def norm1(self):
+ """nn.Module: the normalization layer named "norm1" """
+ return getattr(self, self.norm1_name)
+
+ def _make_stem_layer(self, in_channels, stem_channels):
+ if self.deep_stem:
+ self.stem = nn.Sequential(
+ build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels // 2,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels // 2)[1],
+ nn.ReLU(inplace=True),
+ build_conv_layer(
+ self.conv_cfg,
+ stem_channels // 2,
+ stem_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ bias=False),
+ build_norm_layer(self.norm_cfg, stem_channels)[1],
+ nn.ReLU(inplace=True))
+ else:
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ in_channels,
+ stem_channels,
+ kernel_size=7,
+ stride=2,
+ padding=3,
+ bias=False)
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, stem_channels, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.relu = nn.ReLU(inplace=True)
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ if self.deep_stem:
+ self.stem.eval()
+ for param in self.stem.parameters():
+ param.requires_grad = False
+ else:
+ self.norm1.eval()
+ for m in [self.conv1, self.norm1]:
+ for param in m.parameters():
+ param.requires_grad = False
+
+ for i in range(1, self.frozen_stages + 1):
+ m = getattr(self, f'layer{i}')
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
+ constant_init(m, 1)
+
+ if self.dcn is not None:
+ for m in self.modules():
+ if isinstance(m, Bottleneck) and hasattr(
+ m.conv2, 'conv_offset'):
+ constant_init(m.conv2.conv_offset, 0)
+
+ if self.zero_init_residual:
+ for m in self.modules():
+ if isinstance(m, Bottleneck):
+ constant_init(m.norm3, 0)
+ elif isinstance(m, BasicBlock):
+ constant_init(m.norm2, 0)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+ if self.deep_stem:
+ x = self.stem(x)
+ else:
+ x = self.conv1(x)
+ x = self.norm1(x)
+ x = self.relu(x)
+ x = self.maxpool(x)
+ outs = []
+ for i, layer_name in enumerate(self.res_layers):
+ res_layer = getattr(self, layer_name)
+ x = res_layer(x)
+ if i in self.out_indices:
+ outs.append(x)
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep normalization layer
+ freezed."""
+ super(ResNet, self).train(mode)
+ self._freeze_stages()
+ if mode and self.norm_eval:
+ for m in self.modules():
+                # trick: eval has an effect on BatchNorm layers only
+ if isinstance(m, _BatchNorm):
+ m.eval()
+
+
+@BACKBONES.register_module()
+class ResNetV1d(ResNet):
+ r"""ResNetV1d variant described in `Bag of Tricks
+ `_.
+
+    Compared with the default ResNet (ResNetV1b), ResNetV1d replaces the 7x7 conv in
+ the input stem with three 3x3 convs. And in the downsampling block, a 2x2
+ avg_pool with stride 2 is added before conv, whose stride is changed to 1.
+ """
+
+ def __init__(self, **kwargs):
+ super(ResNetV1d, self).__init__(
+ deep_stem=True, avg_down=True, **kwargs)
diff --git a/mmdet/models/backbones/resnext.py b/mmdet/models/backbones/resnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..6dbcbd516fd308b1d703eecb83ab275f6b159516
--- /dev/null
+++ b/mmdet/models/backbones/resnext.py
@@ -0,0 +1,153 @@
+import math
+
+from mmcv.cnn import build_conv_layer, build_norm_layer
+
+from ..builder import BACKBONES
+from ..utils import ResLayer
+from .resnet import Bottleneck as _Bottleneck
+from .resnet import ResNet
+
+
+class Bottleneck(_Bottleneck):
+ expansion = 4
+
+ def __init__(self,
+ inplanes,
+ planes,
+ groups=1,
+ base_width=4,
+ base_channels=64,
+ **kwargs):
+ """Bottleneck block for ResNeXt.
+
+ If style is "pytorch", the stride-two layer is the 3x3 conv layer, if
+ it is "caffe", the stride-two layer is the first 1x1 conv layer.
+ """
+ super(Bottleneck, self).__init__(inplanes, planes, **kwargs)
+
+ if groups == 1:
+ width = self.planes
+ else:
+ width = math.floor(self.planes *
+ (base_width / base_channels)) * groups
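+ # e.g. for a ResNeXt 32x4d bottleneck with planes=64:
+ # width = floor(64 * 4 / 64) * 32 = 128 grouped channels in conv2.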
+
+ self.norm1_name, norm1 = build_norm_layer(
+ self.norm_cfg, width, postfix=1)
+ self.norm2_name, norm2 = build_norm_layer(
+ self.norm_cfg, width, postfix=2)
+ self.norm3_name, norm3 = build_norm_layer(
+ self.norm_cfg, self.planes * self.expansion, postfix=3)
+
+ self.conv1 = build_conv_layer(
+ self.conv_cfg,
+ self.inplanes,
+ width,
+ kernel_size=1,
+ stride=self.conv1_stride,
+ bias=False)
+ self.add_module(self.norm1_name, norm1)
+ fallback_on_stride = False
+ self.with_modulated_dcn = False
+ if self.with_dcn:
+ fallback_on_stride = self.dcn.pop('fallback_on_stride', False)
+ if not self.with_dcn or fallback_on_stride:
+ self.conv2 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+ else:
+ assert self.conv_cfg is None, 'conv_cfg must be None for DCN'
+ self.conv2 = build_conv_layer(
+ self.dcn,
+ width,
+ width,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ padding=self.dilation,
+ dilation=self.dilation,
+ groups=groups,
+ bias=False)
+
+ self.add_module(self.norm2_name, norm2)
+ self.conv3 = build_conv_layer(
+ self.conv_cfg,
+ width,
+ self.planes * self.expansion,
+ kernel_size=1,
+ bias=False)
+ self.add_module(self.norm3_name, norm3)
+
+ if self.with_plugins:
+ self._del_block_plugins(self.after_conv1_plugin_names +
+ self.after_conv2_plugin_names +
+ self.after_conv3_plugin_names)
+ self.after_conv1_plugin_names = self.make_block_plugins(
+ width, self.after_conv1_plugins)
+ self.after_conv2_plugin_names = self.make_block_plugins(
+ width, self.after_conv2_plugins)
+ self.after_conv3_plugin_names = self.make_block_plugins(
+ self.planes * self.expansion, self.after_conv3_plugins)
+
+ def _del_block_plugins(self, plugin_names):
+ """delete plugins for block if exist.
+
+ Args:
+ plugin_names (list[str]): List of plugin names to delete.
+ """
+ assert isinstance(plugin_names, list)
+ for plugin_name in plugin_names:
+ del self._modules[plugin_name]
+
+
+@BACKBONES.register_module()
+class ResNeXt(ResNet):
+ """ResNeXt backbone.
+
+ Args:
+ depth (int): Depth of ResNeXt, from {50, 101, 152}.
+ in_channels (int): Number of input image channels. Default: 3.
+ num_stages (int): Resnet stages. Default: 4.
+ groups (int): Group of resnext.
+ base_width (int): Base width of resnext.
+ strides (Sequence[int]): Strides of the first block of each stage.
+ dilations (Sequence[int]): Dilation of each stage.
+ out_indices (Sequence[int]): Output from which stages.
+ style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
+ layer is the 3x3 conv layer, otherwise the stride-two layer is
+ the first 1x1 conv layer.
+ frozen_stages (int): Stages to be frozen (all param fixed). -1 means
+ not freezing any parameters.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ norm_eval (bool): Whether to set norm layers to eval mode, namely,
+ freeze running stats (mean and var). Note: Effect on Batch Norm
+ and its variants only.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ zero_init_residual (bool): Whether to use zero init for the last norm
+ layer in resblocks to let them behave as identities.
+ """
+
+ arch_settings = {
+ 50: (Bottleneck, (3, 4, 6, 3)),
+ 101: (Bottleneck, (3, 4, 23, 3)),
+ 152: (Bottleneck, (3, 8, 36, 3))
+ }
+
+ def __init__(self, groups=1, base_width=4, **kwargs):
+ self.groups = groups
+ self.base_width = base_width
+ super(ResNeXt, self).__init__(**kwargs)
+
+ def make_res_layer(self, **kwargs):
+ """Pack all blocks in a stage into a ``ResLayer``"""
+ return ResLayer(
+ groups=self.groups,
+ base_width=self.base_width,
+ base_channels=self.base_channels,
+ **kwargs)
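+
+
+# Illustrative usage sketch (added note, not part of the upstream mmdet file):
+# a ResNeXt-101 32x4d backbone would typically be built from a config roughly
+# like
+#   backbone=dict(type='ResNeXt', depth=101, groups=32, base_width=4,
+#                 num_stages=4, out_indices=(0, 1, 2, 3), frozen_stages=1)
+# with the exact fields depending on the surrounding detector config.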
diff --git a/mmdet/models/backbones/ssd_vgg.py b/mmdet/models/backbones/ssd_vgg.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbc4fbb2301afc002f47abb9ed133a500d6cf23f
--- /dev/null
+++ b/mmdet/models/backbones/ssd_vgg.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import VGG, constant_init, kaiming_init, normal_init, xavier_init
+from mmcv.runner import load_checkpoint
+
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+@BACKBONES.register_module()
+class SSDVGG(VGG):
+ """VGG Backbone network for single-shot-detection.
+
+ Args:
+ input_size (int): width and height of input, from {300, 512}.
+ depth (int): Depth of vgg, from {11, 13, 16, 19}.
+ out_indices (Sequence[int]): Output from which stages.
+
+ Example:
+ >>> self = SSDVGG(input_size=300, depth=11)
+ >>> self.eval()
+ >>> inputs = torch.rand(1, 3, 300, 300)
+ >>> level_outputs = self.forward(inputs)
+ >>> for level_out in level_outputs:
+ ... print(tuple(level_out.shape))
+ (1, 1024, 19, 19)
+ (1, 512, 10, 10)
+ (1, 256, 5, 5)
+ (1, 256, 3, 3)
+ (1, 256, 1, 1)
+ """
+ extra_setting = {
+ 300: (256, 'S', 512, 128, 'S', 256, 128, 256, 128, 256),
+ 512: (256, 'S', 512, 128, 'S', 256, 128, 'S', 256, 128, 'S', 256, 128),
+ }
+
+ def __init__(self,
+ input_size,
+ depth,
+ with_last_pool=False,
+ ceil_mode=True,
+ out_indices=(3, 4),
+ out_feature_indices=(22, 34),
+ l2_norm_scale=20.):
+ # TODO: in_channels for mmcv.VGG
+ super(SSDVGG, self).__init__(
+ depth,
+ with_last_pool=with_last_pool,
+ ceil_mode=ceil_mode,
+ out_indices=out_indices)
+ assert input_size in (300, 512)
+ self.input_size = input_size
+
+ self.features.add_module(
+ str(len(self.features)),
+ nn.MaxPool2d(kernel_size=3, stride=1, padding=1))
+ self.features.add_module(
+ str(len(self.features)),
+ nn.Conv2d(512, 1024, kernel_size=3, padding=6, dilation=6))
+ self.features.add_module(
+ str(len(self.features)), nn.ReLU(inplace=True))
+ self.features.add_module(
+ str(len(self.features)), nn.Conv2d(1024, 1024, kernel_size=1))
+ self.features.add_module(
+ str(len(self.features)), nn.ReLU(inplace=True))
+ self.out_feature_indices = out_feature_indices
+
+ self.inplanes = 1024
+ self.extra = self._make_extra_layers(self.extra_setting[input_size])
+ self.l2_norm = L2Norm(
+ self.features[out_feature_indices[0] - 1].out_channels,
+ l2_norm_scale)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.features.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ elif isinstance(m, nn.Linear):
+ normal_init(m, std=0.01)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ for m in self.extra.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ constant_init(self.l2_norm, self.l2_norm.scale)
+
+ def forward(self, x):
+ """Forward function."""
+ outs = []
+ for i, layer in enumerate(self.features):
+ x = layer(x)
+ if i in self.out_feature_indices:
+ outs.append(x)
+ for i, layer in enumerate(self.extra):
+ x = F.relu(layer(x), inplace=True)
+ if i % 2 == 1:
+ outs.append(x)
+ outs[0] = self.l2_norm(outs[0])
+ if len(outs) == 1:
+ return outs[0]
+ else:
+ return tuple(outs)
+
+ def _make_extra_layers(self, outplanes):
+ layers = []
+ kernel_sizes = (1, 3)
+ num_layers = 0
+ outplane = None
+ for i in range(len(outplanes)):
+ if self.inplanes == 'S':
+ self.inplanes = outplane
+ continue
+ k = kernel_sizes[num_layers % 2]
+ if outplanes[i] == 'S':
+ outplane = outplanes[i + 1]
+ conv = nn.Conv2d(
+ self.inplanes, outplane, k, stride=2, padding=1)
+ else:
+ outplane = outplanes[i]
+ conv = nn.Conv2d(
+ self.inplanes, outplane, k, stride=1, padding=0)
+ layers.append(conv)
+ self.inplanes = outplanes[i]
+ num_layers += 1
+ if self.input_size == 512:
+ layers.append(nn.Conv2d(self.inplanes, 256, 4, padding=1))
+
+ return nn.Sequential(*layers)
+
+
+class L2Norm(nn.Module):
+
+ def __init__(self, n_dims, scale=20., eps=1e-10):
+ """L2 normalization layer.
+
+ Args:
+ n_dims (int): Number of dimensions to be normalized.
+ scale (float, optional): Initial value of the learnable per-channel
+ scale. Defaults to 20.
+ eps (float, optional): Used to avoid division by zero.
+ Defaults to 1e-10.
+ """
+ super(L2Norm, self).__init__()
+ self.n_dims = n_dims
+ self.weight = nn.Parameter(torch.Tensor(self.n_dims))
+ self.eps = eps
+ self.scale = scale
+
+ def forward(self, x):
+ """Forward function."""
+ # compute the normalization in FP32 even in FP16 training
+ x_float = x.float()
+ norm = x_float.pow(2).sum(1, keepdim=True).sqrt() + self.eps
+ return (self.weight[None, :, None, None].float().expand_as(x_float) *
+ x_float / norm).type_as(x)
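+
+
+# Added note: L2Norm rescales every spatial position of an (N, C, H, W) map to
+# unit L2 norm across channels and multiplies by a learned per-channel weight,
+# initialised to ``l2_norm_scale`` (20 by default) in SSDVGG.init_weights.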
diff --git a/mmdet/models/backbones/swin_transformer.py b/mmdet/models/backbones/swin_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb41850d8480a08a6a7698bf6129ffd1ab239681
--- /dev/null
+++ b/mmdet/models/backbones/swin_transformer.py
@@ -0,0 +1,630 @@
+# --------------------------------------------------------
+# Swin Transformer
+# Copyright (c) 2021 Microsoft
+# Licensed under The MIT License [see LICENSE for details]
+# Written by Ze Liu, Yutong Lin, Yixuan Wei
+# --------------------------------------------------------
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+import numpy as np
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from mmcv_custom import load_checkpoint
+from mmdet.utils import get_root_logger
+from ..builder import BACKBONES
+
+
+class Mlp(nn.Module):
+ """ Multilayer perceptron."""
+
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
+ windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
+ return windows
+
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
+ x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
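+
+# Shape note (illustrative): for x of shape (1, 56, 56, 96) and window_size=7,
+# window_partition returns (64, 7, 7, 96) -- 8x8 windows per image -- and
+# window_reverse(windows, 7, 56, 56) restores the original (1, 56, 56, 96).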
+
+
+class WindowAttention(nn.Module):
+ """ Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ """
+
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = qk_scale or head_dim ** -0.5
+
+ # define a parameter table of relative position bias
+ self.relative_position_bias_table = nn.Parameter(
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :] # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
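+ # e.g. for a 7x7 window the table holds (2*7 - 1) * (2*7 - 1) = 169 bias
+ # entries per head, indexed through the 49x49 relative_position_index.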
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+
+ trunc_normal_(self.relative_position_bias_table, std=.02)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """ Forward function.
+
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+ q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
+
+ q = q * self.scale
+ attn = (q @ k.transpose(-2, -1))
+
+ relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+ attn = self.softmax(attn)
+ else:
+ attn = self.softmax(attn)
+
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class SwinTransformerBlock(nn.Module):
+ """ Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, num_heads, window_size=7, shift_size=0,
+ mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
+ act_layer=nn.GELU, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.num_heads = num_heads
+ self.window_size = window_size
+ self.shift_size = shift_size
+ self.mlp_ratio = mlp_ratio
+ assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"
+
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
+ qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
+
+ self.H = None
+ self.W = None
+
+ def forward(self, x, mask_matrix):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ mask_matrix: Attention mask for cyclic shift.
+ """
+ B, L, C = x.shape
+ H, W = self.H, self.W
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = self.norm1(x)
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size - W % self.window_size) % self.window_size
+ pad_b = (self.window_size - H % self.window_size) % self.window_size
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size > 0:
+ shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(shifted_x, self.window_size) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(-1, self.window_size * self.window_size, C) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(x_windows, mask=attn_mask) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size > 0:
+ x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ # FFN
+ x = shortcut + self.drop_path(x)
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
+
+ return x
+
+
+class PatchMerging(nn.Module):
+ """ Patch Merging Layer
+
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(4 * dim)
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ x = x.view(B, H, W, C)
+
+ # padding
+ pad_input = (H % 2 == 1) or (W % 2 == 1)
+ if pad_input:
+ x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.norm(x)
+ x = self.reduction(x)
+
+ return x
+
+
+class BasicLayer(nn.Module):
+ """ A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of feature channels
+ depth (int): Depths of this stage.
+ num_heads (int): Number of attention head.
+ window_size (int): Local window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ dim,
+ depth,
+ num_heads,
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.,
+ attn_drop=0.,
+ drop_path=0.,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False):
+ super().__init__()
+ self.window_size = window_size
+ self.shift_size = window_size // 2
+ self.depth = depth
+ self.use_checkpoint = use_checkpoint
+
+ # build blocks
+ self.blocks = nn.ModuleList([
+ SwinTransformerBlock(
+ dim=dim,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer)
+ for i in range(depth)])
+
+ # patch merging layer
+ if downsample is not None:
+ self.downsample = downsample(dim=dim, norm_layer=norm_layer)
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ """ Forward function.
+
+ Args:
+ x: Input feature, tensor size (B, H*W, C).
+ H, W: Spatial resolution of the input feature.
+ """
+
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size)) * self.window_size
+ Wp = int(np.ceil(W / self.window_size)) * self.window_size
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ w_slices = (slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None))
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
+ mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x, H, W, x_down, Wh, Ww
+ else:
+ return x, H, W, x, H, W
+
+
+class PatchEmbed(nn.Module):
+ """ Image to Patch Embedding
+
+ Args:
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(self, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
+ super().__init__()
+ patch_size = to_2tuple(patch_size)
+ self.patch_size = patch_size
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2)
+ x = self.norm(x)
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+
+@BACKBONES.register_module()
+class SwinTransformer(nn.Module):
+ """ Swin Transformer backbone.
+ A PyTorch impl of: `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ pretrain_img_size (int): Input image size for training the pretrained model,
+ used in absolute position embedding. Default 224.
+ patch_size (int | tuple(int)): Patch size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ depths (tuple[int]): Depths of each Swin Transformer stage.
+ num_heads (tuple[int]): Number of attention head of each stage.
+ window_size (int): Window size. Default: 7.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4.
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set.
+ drop_rate (float): Dropout rate.
+ attn_drop_rate (float): Attention dropout rate. Default: 0.
+ drop_path_rate (float): Stochastic depth rate. Default: 0.2.
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False.
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True.
+ out_indices (Sequence[int]): Output from which stages.
+ frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
+ -1 means not freezing any parameters.
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ """
+
+ def __init__(self,
+ pretrain_img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.,
+ attn_drop_rate=0.,
+ drop_path_rate=0.2,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ out_indices=(0, 1, 2, 3),
+ frozen_stages=-1,
+ use_checkpoint=False):
+ super().__init__()
+
+ self.pretrain_img_size = pretrain_img_size
+ self.num_layers = len(depths)
+ self.embed_dim = embed_dim
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.out_indices = out_indices
+ self.frozen_stages = frozen_stages
+
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None)
+
+ # absolute position embedding
+ if self.ape:
+ pretrain_img_size = to_2tuple(pretrain_img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [pretrain_img_size[0] // patch_size[0], pretrain_img_size[1] // patch_size[1]]
+
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, embed_dim, patches_resolution[0], patches_resolution[1]))
+ trunc_normal_(self.absolute_pos_embed, std=.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2 ** i_layer),
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint)
+ self.layers.append(layer)
+
+ num_features = [int(embed_dim * 2 ** i) for i in range(self.num_layers)]
+ self.num_features = num_features
+
+ # add a norm layer for each output
+ for i_layer in out_indices:
+ layer = norm_layer(num_features[i_layer])
+ layer_name = f'norm{i_layer}'
+ self.add_module(layer_name, layer)
+
+ self._freeze_stages()
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+
+ if self.frozen_stages >= 1 and self.ape:
+ self.absolute_pos_embed.requires_grad = False
+
+ if self.frozen_stages >= 2:
+ self.pos_drop.eval()
+ for i in range(0, self.frozen_stages - 1):
+ m = self.layers[i]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in backbone.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+
+ def _init_weights(m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ if isinstance(pretrained, str):
+ self.apply(_init_weights)
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ self.apply(_init_weights)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ def forward(self, x):
+ """Forward function."""
+ x = self.patch_embed(x)
+
+ Wh, Ww = x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(self.absolute_pos_embed, size=(Wh, Ww), mode='bicubic')
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ outs = []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_out, H, W, x, Wh, Ww = layer(x, Wh, Ww)
+
+ if i in self.out_indices:
+ norm_layer = getattr(self, f'norm{i}')
+ x_out = norm_layer(x_out)
+
+ out = x_out.view(-1, H, W, self.num_features[i]).permute(0, 3, 1, 2).contiguous()
+ outs.append(out)
+
+ return tuple(outs)
+
+ def train(self, mode=True):
+ """Convert the model into training mode while keep layers freezed."""
+ super(SwinTransformer, self).train(mode)
+ self._freeze_stages()
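+
+
+# Illustrative usage sketch (added note, not part of the upstream file): the
+# defaults above match a Swin-T backbone, so roughly
+#   model = SwinTransformer(embed_dim=96, depths=[2, 2, 6, 2],
+#                           num_heads=[3, 6, 12, 24], window_size=7,
+#                           out_indices=(0, 1, 2, 3))
+# returns four feature maps with 96, 192, 384 and 768 channels at strides
+# 4, 8, 16 and 32 respectively.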
diff --git a/mmdet/models/backbones/trident_resnet.py b/mmdet/models/backbones/trident_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6100132b0f4120585da8a309cba4488b4b0ea72
--- /dev/null
+++ b/mmdet/models/backbones/trident_resnet.py
@@ -0,0 +1,292 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as cp
+from mmcv.cnn import build_conv_layer, build_norm_layer, kaiming_init
+from torch.nn.modules.utils import _pair
+
+from mmdet.models.backbones.resnet import Bottleneck, ResNet
+from mmdet.models.builder import BACKBONES
+
+
+class TridentConv(nn.Module):
+ """Trident Convolution Module.
+
+ Args:
+ in_channels (int): Number of channels in input.
+ out_channels (int): Number of channels in output.
+ kernel_size (int): Size of convolution kernel.
+ stride (int, optional): Convolution stride. Default: 1.
+ trident_dilations (tuple[int, int, int], optional): Dilations of
+ different trident branch. Default: (1, 2, 3).
+ test_branch_idx (int, optional): In inference, all 3 branches will
+ be used if `test_branch_idx==-1`, otherwise only branch with
+ index `test_branch_idx` will be used. Default: 1.
+ bias (bool, optional): Whether to use bias in convolution or not.
+ Default: False.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ trident_dilations=(1, 2, 3),
+ test_branch_idx=1,
+ bias=False):
+ super(TridentConv, self).__init__()
+ self.num_branch = len(trident_dilations)
+ self.with_bias = bias
+ self.test_branch_idx = test_branch_idx
+ self.stride = _pair(stride)
+ self.kernel_size = _pair(kernel_size)
+ self.paddings = _pair(trident_dilations)
+ self.dilations = trident_dilations
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.bias = bias
+
+ self.weight = nn.Parameter(
+ torch.Tensor(out_channels, in_channels, *self.kernel_size))
+ if bias:
+ self.bias = nn.Parameter(torch.Tensor(out_channels))
+ else:
+ self.bias = None
+ self.init_weights()
+
+ def init_weights(self):
+ kaiming_init(self, distribution='uniform', mode='fan_in')
+
+ def extra_repr(self):
+ tmpstr = f'in_channels={self.in_channels}'
+ tmpstr += f', out_channels={self.out_channels}'
+ tmpstr += f', kernel_size={self.kernel_size}'
+ tmpstr += f', num_branch={self.num_branch}'
+ tmpstr += f', test_branch_idx={self.test_branch_idx}'
+ tmpstr += f', stride={self.stride}'
+ tmpstr += f', paddings={self.paddings}'
+ tmpstr += f', dilations={self.dilations}'
+ tmpstr += f', bias={self.bias}'
+ return tmpstr
+
+ def forward(self, inputs):
+ if self.training or self.test_branch_idx == -1:
+ outputs = [
+ F.conv2d(input, self.weight, self.bias, self.stride, padding,
+ dilation) for input, dilation, padding in zip(
+ inputs, self.dilations, self.paddings)
+ ]
+ else:
+ assert len(inputs) == 1
+ outputs = [
+ F.conv2d(inputs[0], self.weight, self.bias, self.stride,
+ self.paddings[self.test_branch_idx],
+ self.dilations[self.test_branch_idx])
+ ]
+
+ return outputs
+
+
+# Since TridentNet is defined over ResNet50 and ResNet101, here we
+# only support TridentBottleneckBlock.
+class TridentBottleneck(Bottleneck):
+ """BottleBlock for TridentResNet.
+
+ Args:
+ trident_dilations (tuple[int, int, int]): Dilations of different
+ trident branch.
+ test_branch_idx (int): In inference, all 3 branches will be used
+ if `test_branch_idx==-1`, otherwise only branch with index
+ `test_branch_idx` will be used.
+ concat_output (bool): Whether to concat the output list to a Tensor.
+ `True` only in the last Block.
+ """
+
+ def __init__(self, trident_dilations, test_branch_idx, concat_output,
+ **kwargs):
+
+ super(TridentBottleneck, self).__init__(**kwargs)
+ self.trident_dilations = trident_dilations
+ self.num_branch = len(trident_dilations)
+ self.concat_output = concat_output
+ self.test_branch_idx = test_branch_idx
+ self.conv2 = TridentConv(
+ self.planes,
+ self.planes,
+ kernel_size=3,
+ stride=self.conv2_stride,
+ bias=False,
+ trident_dilations=self.trident_dilations,
+ test_branch_idx=test_branch_idx)
+
+ def forward(self, x):
+
+ def _inner_forward(x):
+ num_branch = (
+ self.num_branch
+ if self.training or self.test_branch_idx == -1 else 1)
+ identity = x
+ if not isinstance(x, list):
+ x = (x, ) * num_branch
+ identity = x
+ if self.downsample is not None:
+ identity = [self.downsample(b) for b in x]
+
+ out = [self.conv1(b) for b in x]
+ out = [self.norm1(b) for b in out]
+ out = [self.relu(b) for b in out]
+
+ if self.with_plugins:
+ for k in range(len(out)):
+ out[k] = self.forward_plugin(out[k],
+ self.after_conv1_plugin_names)
+
+ out = self.conv2(out)
+ out = [self.norm2(b) for b in out]
+ out = [self.relu(b) for b in out]
+ if self.with_plugins:
+ for k in range(len(out)):
+ out[k] = self.forward_plugin(out[k],
+ self.after_conv2_plugin_names)
+
+ out = [self.conv3(b) for b in out]
+ out = [self.norm3(b) for b in out]
+
+ if self.with_plugins:
+ for k in range(len(out)):
+ out[k] = self.forward_plugin(out[k],
+ self.after_conv3_plugin_names)
+
+ out = [
+ out_b + identity_b for out_b, identity_b in zip(out, identity)
+ ]
+ return out
+
+ if self.with_cp and x.requires_grad:
+ out = cp.checkpoint(_inner_forward, x)
+ else:
+ out = _inner_forward(x)
+
+ out = [self.relu(b) for b in out]
+ if self.concat_output:
+ out = torch.cat(out, dim=0)
+ return out
+
+
+def make_trident_res_layer(block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ trident_dilations=(1, 2, 3),
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None,
+ test_branch_idx=-1):
+ """Build Trident Res Layers."""
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ for i in range(num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride if i == 0 else 1,
+ trident_dilations=trident_dilations,
+ downsample=downsample if i == 0 else None,
+ style=style,
+ with_cp=with_cp,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ dcn=dcn,
+ plugins=plugins,
+ test_branch_idx=test_branch_idx,
+ concat_output=True if i == num_blocks - 1 else False))
+ inplanes = planes * block.expansion
+ return nn.Sequential(*layers)
+
+
+@BACKBONES.register_module()
+class TridentResNet(ResNet):
+ """The stem layer, stage 1 and stage 2 in Trident ResNet are identical to
+ ResNet, while in stage 3, Trident BottleBlock is utilized to replace the
+ normal BottleBlock to yield trident output. Different branch shares the
+ convolution weight but uses different dilations to achieve multi-scale
+ output.
+
+ / stage3(b0) \
+ x - stem - stage1 - stage2 - stage3(b1) - output
+ \ stage3(b2) /
+
+ Args:
+ depth (int): Depth of resnet, from {50, 101, 152}.
+ num_branch (int): Number of branches in TridentNet.
+ test_branch_idx (int): In inference, all 3 branches will be used
+ if `test_branch_idx==-1`, otherwise only branch with index
+ `test_branch_idx` will be used.
+ trident_dilations (tuple[int]): Dilations of different trident branch.
+ len(trident_dilations) should be equal to num_branch.
+ """ # noqa
+
+ def __init__(self, depth, num_branch, test_branch_idx, trident_dilations,
+ **kwargs):
+
+ assert num_branch == len(trident_dilations)
+ assert depth in (50, 101, 152)
+ super(TridentResNet, self).__init__(depth, **kwargs)
+ assert self.num_stages == 3
+ self.test_branch_idx = test_branch_idx
+ self.num_branch = num_branch
+
+ last_stage_idx = self.num_stages - 1
+ stride = self.strides[last_stage_idx]
+ dilation = trident_dilations
+ dcn = self.dcn if self.stage_with_dcn[last_stage_idx] else None
+ if self.plugins is not None:
+ stage_plugins = self.make_stage_plugins(self.plugins,
+ last_stage_idx)
+ else:
+ stage_plugins = None
+ planes = self.base_channels * 2**last_stage_idx
+ res_layer = make_trident_res_layer(
+ TridentBottleneck,
+ inplanes=(self.block.expansion * self.base_channels *
+ 2**(last_stage_idx - 1)),
+ planes=planes,
+ num_blocks=self.stage_blocks[last_stage_idx],
+ stride=stride,
+ trident_dilations=dilation,
+ style=self.style,
+ with_cp=self.with_cp,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn,
+ plugins=stage_plugins,
+ test_branch_idx=self.test_branch_idx)
+
+ layer_name = f'layer{last_stage_idx + 1}'
+
+ self.__setattr__(layer_name, res_layer)
+ self.res_layers.pop(last_stage_idx)
+ self.res_layers.insert(last_stage_idx, layer_name)
+
+ self._freeze_stages()
diff --git a/mmdet/models/builder.py b/mmdet/models/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..81c927e507a7c1625ffb114de10e93c94927af25
--- /dev/null
+++ b/mmdet/models/builder.py
@@ -0,0 +1,77 @@
+import warnings
+
+from mmcv.utils import Registry, build_from_cfg
+from torch import nn
+
+BACKBONES = Registry('backbone')
+NECKS = Registry('neck')
+ROI_EXTRACTORS = Registry('roi_extractor')
+SHARED_HEADS = Registry('shared_head')
+HEADS = Registry('head')
+LOSSES = Registry('loss')
+DETECTORS = Registry('detector')
+
+
+def build(cfg, registry, default_args=None):
+ """Build a module.
+
+ Args:
+ cfg (dict | list[dict]): The config of modules; it is either a dict
+ or a list of configs.
+ registry (:obj:`Registry`): A registry the module belongs to.
+ default_args (dict, optional): Default arguments to build the module.
+ Defaults to None.
+
+ Returns:
+ nn.Module: A built nn module.
+ """
+ if isinstance(cfg, list):
+ modules = [
+ build_from_cfg(cfg_, registry, default_args) for cfg_ in cfg
+ ]
+ return nn.Sequential(*modules)
+ else:
+ return build_from_cfg(cfg, registry, default_args)
+
+
+def build_backbone(cfg):
+ """Build backbone."""
+ return build(cfg, BACKBONES)
+
+
+def build_neck(cfg):
+ """Build neck."""
+ return build(cfg, NECKS)
+
+
+def build_roi_extractor(cfg):
+ """Build roi extractor."""
+ return build(cfg, ROI_EXTRACTORS)
+
+
+def build_shared_head(cfg):
+ """Build shared head."""
+ return build(cfg, SHARED_HEADS)
+
+
+def build_head(cfg):
+ """Build head."""
+ return build(cfg, HEADS)
+
+
+def build_loss(cfg):
+ """Build loss."""
+ return build(cfg, LOSSES)
+
+
+def build_detector(cfg, train_cfg=None, test_cfg=None):
+ """Build detector."""
+ if train_cfg is not None or test_cfg is not None:
+ warnings.warn(
+ 'train_cfg and test_cfg are deprecated, '
+ 'please specify them in model', UserWarning)
+ assert cfg.get('train_cfg') is None or train_cfg is None, \
+ 'train_cfg specified in both outer field and model field '
+ assert cfg.get('test_cfg') is None or test_cfg is None, \
+ 'test_cfg specified in both outer field and model field '
+ return build(cfg, DETECTORS, dict(train_cfg=train_cfg, test_cfg=test_cfg))
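+
+
+# Illustrative usage sketch (added note, not part of the upstream file): each
+# builder takes a plain config dict whose 'type' names a registered class, e.g.
+#   backbone = build_backbone(dict(type='ResNet', depth=50, num_stages=4,
+#                                  out_indices=(0, 1, 2, 3)))
+# and a list of such dicts is wrapped into an nn.Sequential by ``build``.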
diff --git a/mmdet/models/dense_heads/__init__.py b/mmdet/models/dense_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f004dd95d97df16167f932587b3ce73b05b04a37
--- /dev/null
+++ b/mmdet/models/dense_heads/__init__.py
@@ -0,0 +1,41 @@
+from .anchor_free_head import AnchorFreeHead
+from .anchor_head import AnchorHead
+from .atss_head import ATSSHead
+from .cascade_rpn_head import CascadeRPNHead, StageCascadeRPNHead
+from .centripetal_head import CentripetalHead
+from .corner_head import CornerHead
+from .embedding_rpn_head import EmbeddingRPNHead
+from .fcos_head import FCOSHead
+from .fovea_head import FoveaHead
+from .free_anchor_retina_head import FreeAnchorRetinaHead
+from .fsaf_head import FSAFHead
+from .ga_retina_head import GARetinaHead
+from .ga_rpn_head import GARPNHead
+from .gfl_head import GFLHead
+from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+from .ld_head import LDHead
+from .nasfcos_head import NASFCOSHead
+from .paa_head import PAAHead
+from .pisa_retinanet_head import PISARetinaHead
+from .pisa_ssd_head import PISASSDHead
+from .reppoints_head import RepPointsHead
+from .retina_head import RetinaHead
+from .retina_sepbn_head import RetinaSepBNHead
+from .rpn_head import RPNHead
+from .sabl_retina_head import SABLRetinaHead
+from .ssd_head import SSDHead
+from .transformer_head import TransformerHead
+from .vfnet_head import VFNetHead
+from .yolact_head import YOLACTHead, YOLACTProtonet, YOLACTSegmHead
+from .yolo_head import YOLOV3Head
+
+__all__ = [
+ 'AnchorFreeHead', 'AnchorHead', 'GuidedAnchorHead', 'FeatureAdaption',
+ 'RPNHead', 'GARPNHead', 'RetinaHead', 'RetinaSepBNHead', 'GARetinaHead',
+ 'SSDHead', 'FCOSHead', 'RepPointsHead', 'FoveaHead',
+ 'FreeAnchorRetinaHead', 'ATSSHead', 'FSAFHead', 'NASFCOSHead',
+ 'PISARetinaHead', 'PISASSDHead', 'GFLHead', 'CornerHead', 'YOLACTHead',
+ 'YOLACTSegmHead', 'YOLACTProtonet', 'YOLOV3Head', 'PAAHead',
+ 'SABLRetinaHead', 'CentripetalHead', 'VFNetHead', 'TransformerHead',
+ 'StageCascadeRPNHead', 'CascadeRPNHead', 'EmbeddingRPNHead', 'LDHead'
+]
diff --git a/mmdet/models/dense_heads/anchor_free_head.py b/mmdet/models/dense_heads/anchor_free_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1814a0cc4f577f470f74f025440073a0aaa1ebd0
--- /dev/null
+++ b/mmdet/models/dense_heads/anchor_free_head.py
@@ -0,0 +1,340 @@
+from abc import abstractmethod
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class AnchorFreeHead(BaseDenseHead, BBoxTestMixin):
+ """Anchor-free head (FCOS, Fovea, RepPoints, etc.).
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels. Used in child classes.
+ stacked_convs (int): Number of stacking convs of the head.
+ strides (tuple): Downsample factor of each feature map.
+ dcn_on_last_conv (bool): If true, use dcn in the last layer of
+ towers. Default: False.
+ conv_bias (bool | str): If specified as `auto`, it will be decided by
+ the norm_cfg. Bias of conv will be set as True if `norm_cfg` is
+ None, otherwise False. Default: "auto".
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ """ # noqa: W605
+
+ _version = 1
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ stacked_convs=4,
+ strides=(4, 8, 16, 32, 64),
+ dcn_on_last_conv=False,
+ conv_bias='auto',
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+ conv_cfg=None,
+ norm_cfg=None,
+ train_cfg=None,
+ test_cfg=None):
+ super(AnchorFreeHead, self).__init__()
+ self.num_classes = num_classes
+ self.cls_out_channels = num_classes
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.stacked_convs = stacked_convs
+ self.strides = strides
+ self.dcn_on_last_conv = dcn_on_last_conv
+ assert conv_bias == 'auto' or isinstance(conv_bias, bool)
+ self.conv_bias = conv_bias
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.fp16_enabled = False
+
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self._init_cls_convs()
+ self._init_reg_convs()
+ self._init_predictor()
+
+ def _init_cls_convs(self):
+ """Initialize classification conv layers of the head."""
+ self.cls_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+
+ def _init_reg_convs(self):
+ """Initialize bbox regression conv layers of the head."""
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ if self.dcn_on_last_conv and i == self.stacked_convs - 1:
+ conv_cfg = dict(type='DCNv2')
+ else:
+ conv_cfg = self.conv_cfg
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias))
+
+ def _init_predictor(self):
+ """Initialize predictor layers of the head."""
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ if isinstance(m.conv, nn.Conv2d):
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ if isinstance(m.conv, nn.Conv2d):
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.conv_cls, std=0.01, bias=bias_cls)
+ normal_init(self.conv_reg, std=0.01)
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """Hack some keys of the model state dict so that can load checkpoints
+ of previous version."""
+ version = local_metadata.get('version', None)
+ if version is None:
+ # the key is different in early versions
+ # for example, 'fcos_cls' become 'conv_cls' now
+ bbox_head_keys = [
+ k for k in state_dict.keys() if k.startswith(prefix)
+ ]
+ ori_predictor_keys = []
+ new_predictor_keys = []
+ # e.g. 'fcos_cls' or 'fcos_reg'
+ for key in bbox_head_keys:
+ ori_predictor_keys.append(key)
+ key = key.split('.')
+ conv_name = None
+ if key[1].endswith('cls'):
+ conv_name = 'conv_cls'
+ elif key[1].endswith('reg'):
+ conv_name = 'conv_reg'
+ elif key[1].endswith('centerness'):
+ conv_name = 'conv_centerness'
+ else:
+ assert NotImplementedError
+ if conv_name is not None:
+ key[1] = conv_name
+ new_predictor_keys.append('.'.join(key))
+ else:
+ ori_predictor_keys.pop(-1)
+ for i in range(len(new_predictor_keys)):
+ state_dict[new_predictor_keys[i]] = state_dict.pop(
+ ori_predictor_keys[i])
+ super()._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys, unexpected_keys,
+ error_msgs)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually contain classification scores and bbox predictions.
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ """
+ return multi_apply(self.forward_single, feats)[:2]
+
+ def forward_single(self, x):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+
+ Returns:
+ tuple: Scores for each class, bbox predictions, and the features
+ after the classification and regression conv layers; some
+ models, such as FCOS, need these features.
+ """
+ cls_feat = x
+ reg_feat = x
+
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.conv_cls(cls_feat)
+
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+ bbox_pred = self.conv_reg(reg_feat)
+ return cls_score, bbox_pred, cls_feat, reg_feat
+
+ @abstractmethod
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ """
+
+ raise NotImplementedError
+
+ @abstractmethod
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg=None,
+ rescale=None):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_points * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W)
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used
+ rescale (bool): If True, return boxes in original image space
+ """
+
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_targets(self, points, gt_bboxes_list, gt_labels_list):
+ """Compute regression, classification and centerness targets for points
+ in multiple images.
+
+ Args:
+ points (list[Tensor]): Points of each fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels_list (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+ """
+ raise NotImplementedError
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points of a single scale level."""
+ h, w = featmap_size
+ x_range = torch.arange(w, dtype=dtype, device=device)
+ y_range = torch.arange(h, dtype=dtype, device=device)
+ y, x = torch.meshgrid(y_range, x_range)
+ if flatten:
+ y = y.flatten()
+ x = x.flatten()
+ return y, x
+
+ def get_points(self, featmap_sizes, dtype, device, flatten=False):
+ """Get points according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ dtype (torch.dtype): Type of points.
+ device (torch.device): Device of points.
+
+ Returns:
+ tuple: points of each image.
+ """
+ mlvl_points = []
+ for i in range(len(featmap_sizes)):
+ mlvl_points.append(
+ self._get_points_single(featmap_sizes[i], self.strides[i],
+ dtype, device, flatten))
+ return mlvl_points
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
diff --git a/mmdet/models/dense_heads/anchor_head.py b/mmdet/models/dense_heads/anchor_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..eea73520572725f547216ab639c1ebbdfb50834c
--- /dev/null
+++ b/mmdet/models/dense_heads/anchor_head.py
@@ -0,0 +1,751 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_anchor_generator,
+ build_assigner, build_bbox_coder, build_sampler,
+ images_to_levels, multi_apply, multiclass_nms, unmap)
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class AnchorHead(BaseDenseHead, BBoxTestMixin):
+ """Anchor-based head (RPN, RetinaNet, SSD, etc.).
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels. Used in child classes.
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8, 16, 32],
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=(.0, .0, .0, .0),
+ target_stds=(1.0, 1.0, 1.0, 1.0)),
+ reg_decoded_bbox=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None):
+ super(AnchorHead, self).__init__()
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.feat_channels = feat_channels
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ # TODO better way to determine whether sample or not
+ self.sampling = loss_cls['type'] not in [
+ 'FocalLoss', 'GHMC', 'QualityFocalLoss'
+ ]
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+
+ if self.cls_out_channels <= 0:
+ raise ValueError(f'num_classes={num_classes} is too small')
+ self.reg_decoded_bbox = reg_decoded_bbox
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.fp16_enabled = False
+
+ self.anchor_generator = build_anchor_generator(anchor_generator)
+ # usually the numbers of anchors for each level are the same
+ # except SSD detectors
+ self.num_anchors = self.anchor_generator.num_base_anchors[0]
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.conv_cls = nn.Conv2d(self.in_channels,
+ self.num_anchors * self.cls_out_channels, 1)
+ self.conv_reg = nn.Conv2d(self.in_channels, self.num_anchors * 4, 1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ normal_init(self.conv_cls, std=0.01)
+ normal_init(self.conv_reg, std=0.01)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+ cls_score (Tensor): Cls scores for a single scale level \
+ the channels number is num_anchors * num_classes.
+ bbox_pred (Tensor): Box energies / deltas for a single scale \
+ level, the channels number is num_anchors * 4.
+ """
+ cls_score = self.conv_cls(x)
+ bbox_pred = self.conv_reg(x)
+ return cls_score, bbox_pred
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: A tuple of classification scores and bbox prediction.
+
+ - cls_scores (list[Tensor]): Classification scores for all \
+ scale levels, each is a 4D-tensor, the channels number \
+ is num_anchors * num_classes.
+ - bbox_preds (list[Tensor]): Box energies / deltas for all \
+ scale levels, each is a 4D-tensor, the channels number \
+ is num_anchors * 4.
+ """
+ return multi_apply(self.forward_single, feats)
+
+ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
+ """Get anchors according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): Device for returned tensors
+
+ Returns:
+ tuple:
+                anchor_list (list[list[Tensor]]): Anchors of each image.
+                valid_flag_list (list[list[Tensor]]): Valid flags of each image.
+ """
+ num_imgs = len(img_metas)
+
+        # since feature map sizes of all images are the same, we only compute
+        # anchors once and reuse them for all images
+ multi_level_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device)
+ anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+ # for each image, we compute valid flags of multi level anchors
+ valid_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = self.anchor_generator.valid_flags(
+ featmap_sizes, img_meta['pad_shape'], device)
+ valid_flag_list.append(multi_level_flags)
+
+ return anchor_list, valid_flag_list
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image, which are
+                concatenated into a single tensor of shape (num_anchors, 4)
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+            gt_labels (Tensor): Ground truth labels of each box,
+                shape (num_gts,).
+            img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+            tuple:
+                labels (Tensor): Labels of all anchors in the image.
+                label_weights (Tensor): Label weights of all anchors.
+                bbox_targets (Tensor): BBox regression targets of all anchors.
+                bbox_weights (Tensor): BBox regression weights of all anchors.
+                pos_inds (Tensor): Indices of positive anchors.
+                neg_inds (Tensor): Indices of negative anchors.
+                sampling_result (:obj:`SamplingResult`): Sampling result.
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ assign_result = self.assigner.assign(
+ anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
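+        # all labels start as background (num_classes is the BG index since mmdet v2.0)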
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags,
+ fill=self.num_classes) # fill bg label
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds, sampling_result)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ return_sampling_results=False):
+ """Compute regression and classification targets for anchors in
+ multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each \
+ level.
+ - bbox_targets_list (list[Tensor]): BBox targets of each level.
+ - bbox_weights_list (list[Tensor]): BBox weights of each level.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+ additional_returns: This function enables user-defined returns from
+ `self._get_targets_single`. These returns are currently refined
+ to properties at each feature map (i.e. having HxW dimension).
+                The results will be concatenated after the default returns.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors to a single tensor
+ concat_anchor_list = []
+ concat_valid_flag_list = []
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ results = multi_apply(
+ self._get_targets_single,
+ concat_anchor_list,
+ concat_valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
+ pos_inds_list, neg_inds_list, sampling_results_list) = results[:7]
+ rest_results = list(results[7:]) # user-added return values
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
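+        # clamp numel() to at least 1 per image so the loss averaging factor cannot be zero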
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ res = (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+ if return_sampling_results:
+ res = res + (sampling_results_list, )
+ for i, r in enumerate(rest_results): # user-added return values
+ rest_results[i] = images_to_levels(r, num_level_anchors)
+
+ return res + tuple(rest_results)
+
+ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
+ bbox_targets, bbox_weights, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+ bbox_pred (Tensor): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W).
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+ shape (N, num_total_anchors, 4).
+ bbox_weights (Tensor): BBox regression loss weights of each anchor
+ with shape (N, num_total_anchors, 4).
+            num_total_samples (int): If sampling, this is the total number of
+                sampled anchors (positives + negatives); otherwise, it is the
+                number of positive anchors.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # regression loss
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ anchors = anchors.reshape(-1, 4)
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+ loss_bbox = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+ return loss_cls, loss_bbox
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
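+        # with a sampler, losses are normalized over pos + neg samples; otherwise over positives only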
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each level in the
+ feature pyramid, has shape
+ (N, num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each
+ level in the feature pyramid, has shape
+ (N, num_anchors * 4, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+                The first item is an (n, 5) tensor, where 5 represents
+                (tl_x, tl_y, br_x, br_y, score) and the score is between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+
+ Example:
+ >>> import mmcv
+ >>> self = AnchorHead(
+ >>> num_classes=9,
+ >>> in_channels=1,
+ >>> anchor_generator=dict(
+ >>> type='AnchorGenerator',
+ >>> scales=[8],
+ >>> ratios=[0.5, 1.0, 2.0],
+ >>> strides=[4,]))
+ >>> img_metas = [{'img_shape': (32, 32, 3), 'scale_factor': 1}]
+ >>> cfg = mmcv.Config(dict(
+ >>> score_thr=0.00,
+ >>> nms=dict(type='nms', iou_thr=1.0),
+ >>> max_per_img=10))
+ >>> feat = torch.rand(1, 1, 3, 3)
+ >>> cls_score, bbox_pred = self.forward_single(feat)
+ >>> # note the input lists are over different levels, not images
+ >>> cls_scores, bbox_preds = [cls_score], [bbox_pred]
+ >>> result_list = self.get_bboxes(cls_scores, bbox_preds,
+ >>> img_metas, cfg)
+ >>> det_bboxes, det_labels = result_list[0]
+ >>> assert len(result_list) == 1
+ >>> assert det_bboxes.shape[1] == 5
+ >>> assert len(det_bboxes) == len(det_labels) == cfg.max_per_img
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+
+ mlvl_cls_scores = [cls_scores[i].detach() for i in range(num_levels)]
+ mlvl_bbox_preds = [bbox_preds[i].detach() for i in range(num_levels)]
+
+ if torch.onnx.is_in_onnx_export():
+ assert len(
+ img_metas
+ ) == 1, 'Only support one input image while in exporting to ONNX'
+ img_shapes = img_metas[0]['img_shape_for_onnx']
+ else:
+ img_shapes = [
+ img_metas[i]['img_shape']
+ for i in range(cls_scores[0].shape[0])
+ ]
+ scale_factors = [
+ img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0])
+ ]
+
+ if with_nms:
+ # some heads don't support with_nms argument
+ result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds,
+ mlvl_anchors, img_shapes,
+ scale_factors, cfg, rescale)
+ else:
+ result_list = self._get_bboxes(mlvl_cls_scores, mlvl_bbox_preds,
+ mlvl_anchors, img_shapes,
+ scale_factors, cfg, rescale,
+ with_nms)
+ return result_list
+
+ def _get_bboxes(self,
+ mlvl_cls_scores,
+ mlvl_bbox_preds,
+ mlvl_anchors,
+ img_shapes,
+ scale_factors,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a batch item into bbox predictions.
+
+ Args:
+ mlvl_cls_scores (list[Tensor]): Each element in the list is
+ the scores of bboxes of single level in the feature pyramid,
+ has shape (N, num_anchors * num_classes, H, W).
+ mlvl_bbox_preds (list[Tensor]): Each element in the list is the
+ bboxes predictions of single level in the feature pyramid,
+ has shape (N, num_anchors * 4, H, W).
+ mlvl_anchors (list[Tensor]): Each element in the list is
+ the anchors of single level in feature pyramid, has shape
+ (num_anchors, 4).
+            img_shapes (list[tuple[int]]): Each tuple in the list represents
+                the shape (height, width, 3) of a single image in the batch.
+            scale_factors (list[ndarray]): Scale factors of the batch images,
+                arranged as list[(w_scale, h_scale, w_scale, h_scale)].
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+                The first item is an (n, 5) tensor, where 5 represents
+                (tl_x, tl_y, br_x, br_y, score) and the score is between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(mlvl_cls_scores) == len(mlvl_bbox_preds) == len(
+ mlvl_anchors)
+ batch_size = mlvl_cls_scores[0].shape[0]
+ # convert to tensor to keep tracing
+ nms_pre_tensor = torch.tensor(
+ cfg.get('nms_pre', -1),
+ device=mlvl_cls_scores[0].device,
+ dtype=torch.long)
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for cls_score, bbox_pred, anchors in zip(mlvl_cls_scores,
+ mlvl_bbox_preds,
+ mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(batch_size, -1,
+ self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(batch_size, -1, 4)
+ anchors = anchors.expand_as(bbox_pred)
+ # Always keep topk op for dynamic input in onnx
+ if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
+ or scores.shape[-2] > nms_pre_tensor):
+ from torch import _shape_as_tensor
+ # keep shape as tensor and get k
+ num_anchor = _shape_as_tensor(scores)[-2].to(
+ nms_pre_tensor.device)
+ nms_pre = torch.where(nms_pre_tensor < num_anchor,
+ nms_pre_tensor, num_anchor)
+
+ # Get maximum scores for foreground classes.
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(-1)
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = scores[..., :-1].max(-1)
+
+ _, topk_inds = max_scores.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds)
+ anchors = anchors[batch_inds, topk_inds, :]
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shapes)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+
+ batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+ if rescale:
+ batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+
+        # Set the max number of boxes to be fed into nms in deployment
+ deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
+ if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
+ # Get maximum scores for foreground classes.
+ if self.use_sigmoid_cls:
+ max_scores, _ = batch_mlvl_scores.max(-1)
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = batch_mlvl_scores[..., :-1].max(-1)
+ _, topk_inds = max_scores.topk(deploy_nms_pre)
+ batch_inds = torch.arange(batch_size).view(-1,
+ 1).expand_as(topk_inds)
+ batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds]
+ batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds]
+ if self.use_sigmoid_cls:
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = batch_mlvl_scores.new_zeros(batch_size,
+ batch_mlvl_scores.shape[1],
+ 1)
+ batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
+
+ if with_nms:
+ det_results = []
+ for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes,
+ batch_mlvl_scores):
+ det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ det_results.append(tuple([det_bbox, det_label]))
+ else:
+ det_results = [
+ tuple(mlvl_bs)
+ for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores)
+ ]
+ return det_results
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
diff --git a/mmdet/models/dense_heads/atss_head.py b/mmdet/models/dense_heads/atss_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff55dfa1790ba270539fc9f623dbb2984fa1a99e
--- /dev/null
+++ b/mmdet/models/dense_heads/atss_head.py
@@ -0,0 +1,689 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_assigner, build_sampler,
+ images_to_levels, multi_apply, multiclass_nms,
+ reduce_mean, unmap)
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+EPS = 1e-12
+
+
+@HEADS.register_module()
+class ATSSHead(AnchorHead):
+ """Bridging the Gap Between Anchor-based and Anchor-free Detection via
+ Adaptive Training Sample Selection.
+
+    The ATSS head structure is similar to FCOS, but ATSS uses anchor boxes
+    and assigns labels by Adaptive Training Sample Selection instead of max-IoU.
+
+ https://arxiv.org/abs/1912.02424
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ loss_centerness=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(ATSSHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+            # sampling is disabled for ATSS, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.loss_centerness = build_loss(loss_centerness)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.atss_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.atss_reg = nn.Conv2d(
+ self.feat_channels, self.num_anchors * 4, 3, padding=1)
+ self.atss_centerness = nn.Conv2d(
+ self.feat_channels, self.num_anchors * 1, 3, padding=1)
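+        # one learnable Scale factor per feature level (one per anchor stride)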
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.anchor_generator.strides])
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.atss_cls, std=0.01, bias=bias_cls)
+ normal_init(self.atss_reg, std=0.01)
+ normal_init(self.atss_centerness, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * 4.
+ """
+ return multi_apply(self.forward_single, feats, self.scales)
+
+ def forward_single(self, x, scale):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+
+ Returns:
+ tuple:
+                cls_score (Tensor): Cls scores for a single scale level,
+                    the channels number is num_anchors * num_classes.
+                bbox_pred (Tensor): Box energies / deltas for a single scale
+                    level, the channels number is num_anchors * 4.
+                centerness (Tensor): Centerness for a single scale level
+                    with shape (N, num_anchors * 1, H, W).
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.atss_cls(cls_feat)
+        # following ATSS, we do not apply exp to bbox_pred
+ bbox_pred = scale(self.atss_reg(reg_feat)).float()
+ centerness = self.atss_centerness(reg_feat)
+ return cls_score, bbox_pred, centerness
+
+ def loss_single(self, anchors, cls_score, bbox_pred, centerness, labels,
+ label_weights, bbox_targets, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+            bbox_pred (Tensor): Box energies / deltas for each scale
+                level with shape (N, num_anchors * 4, H, W).
+            centerness (Tensor): Centerness for each scale level with shape
+                (N, num_anchors * 1, H, W).
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3, 1).reshape(
+ -1, self.cls_out_channels).contiguous()
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ centerness = centerness.permute(0, 2, 3, 1).reshape(-1)
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ # classification loss
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+ pos_centerness = centerness[pos_inds]
+
+ centerness_targets = self.centerness_target(
+ pos_anchors, pos_bbox_targets)
+ pos_decode_bbox_pred = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_pred)
+ pos_decode_bbox_targets = self.bbox_coder.decode(
+ pos_anchors, pos_bbox_targets)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ weight=centerness_targets,
+ avg_factor=1.0)
+
+ # centerness loss
+ loss_centerness = self.loss_centerness(
+ pos_centerness,
+ centerness_targets,
+ avg_factor=num_total_samples)
+
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ loss_centerness = centerness.sum() * 0
+ centerness_targets = bbox_targets.new_tensor(0.)
+
+ return loss_cls, loss_bbox, loss_centerness, centerness_targets.sum()
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ centernesses (list[Tensor]): Centerness for each scale
+ level with shape (N, num_anchors * 1, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
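+        # clamp to at least one sample to avoid dividing by zero in the losses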
+ num_total_samples = max(num_total_samples, 1.0)
+
+ losses_cls, losses_bbox, loss_centerness,\
+ bbox_avg_factor = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ num_total_samples=num_total_samples)
+
+ bbox_avg_factor = sum(bbox_avg_factor)
+ bbox_avg_factor = reduce_mean(bbox_avg_factor).item()
+ if bbox_avg_factor < EPS:
+ bbox_avg_factor = 1
+ losses_bbox = list(map(lambda x: x / bbox_avg_factor, losses_bbox))
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_centerness=loss_centerness)
+
+ def centerness_target(self, anchors, bbox_targets):
+ # only calculate pos centerness targets, otherwise there may be nan
+ gts = self.bbox_coder.decode(anchors, bbox_targets)
+ anchors_cx = (anchors[:, 2] + anchors[:, 0]) / 2
+ anchors_cy = (anchors[:, 3] + anchors[:, 1]) / 2
+ l_ = anchors_cx - gts[:, 0]
+ t_ = anchors_cy - gts[:, 1]
+ r_ = gts[:, 2] - anchors_cx
+ b_ = gts[:, 3] - anchors_cy
+
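+        # FCOS-style centerness: sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))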
+ left_right = torch.stack([l_, r_], dim=1)
+ top_bottom = torch.stack([t_, b_], dim=1)
+ centerness = torch.sqrt(
+ (left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) *
+ (top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0]))
+ assert not torch.isnan(centerness).any()
+ return centerness
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ with shape (N, num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W).
+ centernesses (list[Tensor]): Centerness for each scale level with
+ shape (N, num_anchors * 1, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default: None.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+                The first item is an (n, 5) tensor, where 5 represents
+                (tl_x, tl_y, br_x, br_y, score) and the score is between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+
+ cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
+ bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)]
+ centerness_pred_list = [
+ centernesses[i].detach() for i in range(num_levels)
+ ]
+ img_shapes = [
+ img_metas[i]['img_shape'] for i in range(cls_scores[0].shape[0])
+ ]
+ scale_factors = [
+ img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0])
+ ]
+ result_list = self._get_bboxes(cls_score_list, bbox_pred_list,
+ centerness_pred_list, mlvl_anchors,
+ img_shapes, scale_factors, cfg, rescale,
+ with_nms)
+ return result_list
+
+ def _get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ mlvl_anchors,
+ img_shapes,
+ scale_factors,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into labeled boxes.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for a single scale level
+ with shape (N, num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for a single
+ scale level with shape (N, num_anchors * 4, H, W).
+ centernesses (list[Tensor]): Centerness for a single scale level
+ with shape (N, num_anchors * 1, H, W).
+ mlvl_anchors (list[Tensor]): Box reference for a single scale level
+ with shape (num_total_anchors, 4).
+ img_shapes (list[tuple[int]]): Shape of the input image,
+ list[(height, width, 3)].
+            scale_factors (list[ndarray]): Scale factors of the images, arranged as
+ (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+                The first item is an (n, 5) tensor, where 5 represents
+                (tl_x, tl_y, br_x, br_y, score) and the score is between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+ device = cls_scores[0].device
+ batch_size = cls_scores[0].shape[0]
+ # convert to tensor to keep tracing
+ nms_pre_tensor = torch.tensor(
+ cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_centerness = []
+ for cls_score, bbox_pred, centerness, anchors in zip(
+ cls_scores, bbox_preds, centernesses, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ scores = cls_score.permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.cls_out_channels).sigmoid()
+ centerness = centerness.permute(0, 2, 3,
+ 1).reshape(batch_size,
+ -1).sigmoid()
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(batch_size, -1, 4)
+
+ # Always keep topk op for dynamic input in onnx
+ if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
+ or scores.shape[-2] > nms_pre_tensor):
+ from torch import _shape_as_tensor
+ # keep shape as tensor and get k
+ num_anchor = _shape_as_tensor(scores)[-2].to(device)
+ nms_pre = torch.where(nms_pre_tensor < num_anchor,
+ nms_pre_tensor, num_anchor)
+
+ max_scores, _ = (scores * centerness[..., None]).max(-1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+ centerness = centerness[batch_inds, topk_inds]
+ else:
+ anchors = anchors.expand_as(bbox_pred)
+
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shapes)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_centerness.append(centerness)
+
+ batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+ if rescale:
+ batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+ batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
+
+        # Set the max number of boxes to be fed into nms in deployment
+ deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
+ if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
+            # Get the maximum centerness-weighted scores over foreground classes.
+            max_scores, _ = (
+                batch_mlvl_scores *
+                batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores)
+            ).max(-1)
+            _, topk_inds = max_scores.topk(deploy_nms_pre)
+ batch_inds = torch.arange(batch_size).view(-1,
+ 1).expand_as(topk_inds)
+ batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
+ batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
+ batch_mlvl_centerness = batch_mlvl_centerness[batch_inds,
+ topk_inds]
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = batch_mlvl_scores.new_zeros(batch_size,
+ batch_mlvl_scores.shape[1], 1)
+ batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
+
+ if with_nms:
+ det_results = []
+ for (mlvl_bboxes, mlvl_scores,
+ mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+ batch_mlvl_centerness):
+ det_bbox, det_label = multiclass_nms(
+ mlvl_bboxes,
+ mlvl_scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=mlvl_centerness)
+ det_results.append(tuple([det_bbox, det_label]))
+ else:
+ det_results = [
+ tuple(mlvl_bs)
+ for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+ batch_mlvl_centerness)
+ ]
+ return det_results
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get targets for ATSS head.
+
+ This method is almost the same as `AnchorHead.get_targets()`. Besides
+ returning the targets as the parent method does, it also returns the
+ anchors as the first element of the returned tuple.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list[i] = torch.cat(anchor_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors)
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def _get_target_single(self,
+ flat_anchors,
+ valid_flags,
+ num_level_anchors,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image, which are
+                concatenated into a single tensor of shape (num_anchors, 4)
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+            num_level_anchors (list[int]): Number of anchors of each scale level.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: N is the number of total anchors in the image.
+ labels (Tensor): Labels of all anchors in the image with shape
+ (N,).
+ label_weights (Tensor): Label weights of all anchor in the
+ image with shape (N,).
+ bbox_targets (Tensor): BBox targets of all anchors in the
+ image with shape (N, 4).
+ bbox_weights (Tensor): BBox weights of all anchors in the
+ image with shape (N, 4)
+ pos_inds (Tensor): Indices of positive anchor with shape
+ (num_pos,).
+ neg_inds (Tensor): Indices of negative anchor with shape
+ (num_neg,).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ num_level_anchors_inside = self.get_num_level_anchors_inside(
+ num_level_anchors, inside_flags)
+ assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
+ gt_bboxes, gt_bboxes_ignore,
+ gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if hasattr(self, 'bbox_coder'):
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ # used in VFNetHead
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class since v2.5.0
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+ pos_inds, neg_inds)
+
+ def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
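+        # count how many of the anchors inside the image border fall on each pyramid level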
+ split_inside_flags = torch.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
diff --git a/mmdet/models/dense_heads/base_dense_head.py b/mmdet/models/dense_heads/base_dense_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..de11e4a2197b1dfe241ce7a66daa1907a8fc5661
--- /dev/null
+++ b/mmdet/models/dense_heads/base_dense_head.py
@@ -0,0 +1,59 @@
+from abc import ABCMeta, abstractmethod
+
+import torch.nn as nn
+
+
+class BaseDenseHead(nn.Module, metaclass=ABCMeta):
+ """Base class for DenseHeads."""
+
+ def __init__(self):
+ super(BaseDenseHead, self).__init__()
+
+ @abstractmethod
+ def loss(self, **kwargs):
+ """Compute losses of the head."""
+ pass
+
+ @abstractmethod
+ def get_bboxes(self, **kwargs):
+ """Transform network output for a batch into bbox predictions."""
+ pass
+
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """
+ Args:
+ x (list[Tensor]): Features from FPN.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used
+
+ Returns:
+ tuple:
+ losses: (dict[str, Tensor]): A dictionary of loss components.
+ proposal_list (list[Tensor]): Proposals of each image.
+ """
+ outs = self(x)
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ if proposal_cfg is None:
+ return losses
+ else:
+ proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
+ return losses, proposal_list
diff --git a/mmdet/models/dense_heads/cascade_rpn_head.py b/mmdet/models/dense_heads/cascade_rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e32ee461951e685fb44a461033293159e3439717
--- /dev/null
+++ b/mmdet/models/dense_heads/cascade_rpn_head.py
@@ -0,0 +1,784 @@
+from __future__ import division
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+from mmcv import ConfigDict
+from mmcv.cnn import normal_init
+from mmcv.ops import DeformConv2d, batched_nms
+
+from mmdet.core import (RegionAssigner, build_assigner, build_sampler,
+ images_to_levels, multi_apply)
+from ..builder import HEADS, build_head
+from .base_dense_head import BaseDenseHead
+from .rpn_head import RPNHead
+
+
+class AdaptiveConv(nn.Module):
+ """AdaptiveConv used to adapt the sampling location with the anchors.
+
+ Args:
+        in_channels (int): Number of channels in the input feature map
+ out_channels (int): Number of channels produced by the convolution
+ kernel_size (int or tuple): Size of the conv kernel. Default: 3
+ stride (int or tuple, optional): Stride of the convolution. Default: 1
+ padding (int or tuple, optional): Zero-padding added to both sides of
+ the input. Default: 1
+ dilation (int or tuple, optional): Spacing between kernel elements.
+ Default: 3
+ groups (int, optional): Number of blocked connections from input
+ channels to output channels. Default: 1
+ bias (bool, optional): If set True, adds a learnable bias to the
+ output. Default: False.
+ type (str, optional): Type of adaptive conv, can be either 'offset'
+ (arbitrary anchors) or 'dilation' (uniform anchor).
+ Default: 'dilation'.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ stride=1,
+ padding=1,
+ dilation=3,
+ groups=1,
+ bias=False,
+ type='dilation'):
+ super(AdaptiveConv, self).__init__()
+ assert type in ['offset', 'dilation']
+ self.adapt_type = type
+
+        assert kernel_size == 3, 'Adaptive conv only supports kernel_size 3'
+ if self.adapt_type == 'offset':
+ assert stride == 1 and padding == 1 and groups == 1, \
+                'Adaptive conv offset mode only supports ' \
+                'padding: 1, stride: 1, groups: 1'
+ self.conv = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=padding,
+ stride=stride,
+ groups=groups,
+ bias=bias)
+ else:
+ self.conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=dilation,
+ dilation=dilation)
+
+ def init_weights(self):
+ """Init weights."""
+ normal_init(self.conv, std=0.01)
+
+ def forward(self, x, offset):
+ """Forward function."""
+ if self.adapt_type == 'offset':
+ N, _, H, W = x.shape
+ assert offset is not None
+ assert H * W == offset.shape[1]
+ # reshape [N, NA, 18] to (N, 18, H, W)
+ offset = offset.permute(0, 2, 1).reshape(N, -1, H, W)
+ offset = offset.contiguous()
+ x = self.conv(x, offset)
+ else:
+ assert offset is None
+ x = self.conv(x)
+ return x
+
+
+@HEADS.register_module()
+class StageCascadeRPNHead(RPNHead):
+ """Stage of CascadeRPNHead.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ anchor_generator (dict): anchor generator config.
+ adapt_cfg (dict): adaptation config.
+        bridged_feature (bool, optional): whether to update the rpn feature.
+            Default: False.
+        with_cls (bool, optional): whether to use the classification branch.
+            Default: True.
+        sampling (bool, optional): whether to use sampling. Default: True.
+ """
+
+ def __init__(self,
+ in_channels,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ scales=[8],
+ ratios=[1.0],
+ strides=[4, 8, 16, 32, 64]),
+ adapt_cfg=dict(type='dilation', dilation=3),
+ bridged_feature=False,
+ with_cls=True,
+ sampling=True,
+ **kwargs):
+ self.with_cls = with_cls
+ self.anchor_strides = anchor_generator['strides']
+ self.anchor_scales = anchor_generator['scales']
+ self.bridged_feature = bridged_feature
+ self.adapt_cfg = adapt_cfg
+ super(StageCascadeRPNHead, self).__init__(
+ in_channels, anchor_generator=anchor_generator, **kwargs)
+
+ # override sampling and sampler
+ self.sampling = sampling
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ def _init_layers(self):
+ """Init layers of a CascadeRPN stage."""
+ self.rpn_conv = AdaptiveConv(self.in_channels, self.feat_channels,
+ **self.adapt_cfg)
+ if self.with_cls:
+ self.rpn_cls = nn.Conv2d(self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 1)
+ self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def init_weights(self):
+ """Init weights of a CascadeRPN stage."""
+ self.rpn_conv.init_weights()
+ normal_init(self.rpn_reg, std=0.01)
+ if self.with_cls:
+ normal_init(self.rpn_cls, std=0.01)
+
+ def forward_single(self, x, offset):
+ """Forward function of single scale."""
+ bridged_x = x
+ x = self.relu(self.rpn_conv(x, offset))
+ if self.bridged_feature:
+ bridged_x = x # update feature
+ cls_score = self.rpn_cls(x) if self.with_cls else None
+ bbox_pred = self.rpn_reg(x)
+ return bridged_x, cls_score, bbox_pred
+
+ def forward(self, feats, offset_list=None):
+ """Forward function."""
+ if offset_list is None:
+ offset_list = [None for _ in range(len(feats))]
+ return multi_apply(self.forward_single, feats, offset_list)
+
+ def _region_targets_single(self,
+ anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ featmap_sizes,
+ label_channels=1):
+ """Get anchor targets based on region for single level."""
+ assign_result = self.assigner.assign(
+ anchors,
+ valid_flags,
+ gt_bboxes,
+ img_meta,
+ featmap_sizes,
+ self.anchor_scales[0],
+ self.anchor_strides,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ gt_labels=None,
+ allowed_border=self.train_cfg.allowed_border)
+ flat_anchors = torch.cat(anchors)
+ sampling_result = self.sampler.sample(assign_result, flat_anchors,
+ gt_bboxes)
+
+ num_anchors = flat_anchors.shape[0]
+ bbox_targets = torch.zeros_like(flat_anchors)
+ bbox_weights = torch.zeros_like(flat_anchors)
+ labels = flat_anchors.new_zeros(num_anchors, dtype=torch.long)
+ label_weights = flat_anchors.new_zeros(num_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
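+                # RPN stage has no class labels; positives are marked 1 over the all-zero background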
+ labels[pos_inds] = 1
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds)
+
+ def region_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """See :func:`StageCascadeRPNHead.get_targets`."""
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_labels, all_label_weights, all_bbox_targets, all_bbox_weights,
+ pos_inds_list, neg_inds_list) = multi_apply(
+ self._region_targets_single,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ featmap_sizes=featmap_sizes,
+ label_channels=label_channels)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ return (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore=None,
+ label_channels=1):
+ """Compute regression and classification targets for anchors.
+
+ Args:
+ anchor_list (list[list]): Multi level anchors of each image.
+ valid_flag_list (list[list]): Multi level valid flags of each
+ image.
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ featmap_sizes (list[Tensor]): Feature mapsize each level
+ gt_bboxes_ignore (list[Tensor]): Ignore bboxes of each images
+ label_channels (int): Channel of label.
+
+ Returns:
+ cls_reg_targets (tuple)
+ """
+ if isinstance(self.assigner, RegionAssigner):
+ cls_reg_targets = self.region_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ label_channels=label_channels)
+ else:
+ cls_reg_targets = super(StageCascadeRPNHead, self).get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ label_channels=label_channels)
+ return cls_reg_targets
+
+ def anchor_offset(self, anchor_list, anchor_strides, featmap_sizes):
+ """ Get offest for deformable conv based on anchor shape
+ NOTE: currently support deformable kernel_size=3 and dilation=1
+
+ Args:
+ anchor_list (list[list[tensor])): [NI, NLVL, NA, 4] list of
+ multi-level anchors
+ anchor_strides (list[int]): anchor stride of each level
+
+ Returns:
+ offset_list (list[tensor]): [NLVL, NA, 2, 18]: offset of DeformConv
+ kernel.
+ """
+
+ def _shape_offset(anchors, stride, ks=3, dilation=1):
+ # currently support kernel_size=3 and dilation=1
+ assert ks == 3 and dilation == 1
+ pad = (ks - 1) // 2
+ idx = torch.arange(-pad, pad + 1, dtype=dtype, device=device)
+ yy, xx = torch.meshgrid(idx, idx) # return order matters
+ xx = xx.reshape(-1)
+ yy = yy.reshape(-1)
+ w = (anchors[:, 2] - anchors[:, 0]) / stride
+ h = (anchors[:, 3] - anchors[:, 1]) / stride
+ w = w / (ks - 1) - dilation
+ h = h / (ks - 1) - dilation
+ offset_x = w[:, None] * xx # (NA, ks**2)
+ offset_y = h[:, None] * yy # (NA, ks**2)
+ return offset_x, offset_y
+
+ def _ctr_offset(anchors, stride, featmap_size):
+ feat_h, feat_w = featmap_size
+ assert len(anchors) == feat_h * feat_w
+
+ x = (anchors[:, 0] + anchors[:, 2]) * 0.5
+ y = (anchors[:, 1] + anchors[:, 3]) * 0.5
+ # compute centers on feature map
+ x = x / stride
+ y = y / stride
+ # compute predefine centers
+ xx = torch.arange(0, feat_w, device=anchors.device)
+ yy = torch.arange(0, feat_h, device=anchors.device)
+ yy, xx = torch.meshgrid(yy, xx)
+ xx = xx.reshape(-1).type_as(x)
+ yy = yy.reshape(-1).type_as(y)
+
+ offset_x = x - xx # (NA, )
+ offset_y = y - yy # (NA, )
+ return offset_x, offset_y
+
+ num_imgs = len(anchor_list)
+ num_lvls = len(anchor_list[0])
+ dtype = anchor_list[0][0].dtype
+ device = anchor_list[0][0].device
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ offset_list = []
+ for i in range(num_imgs):
+ mlvl_offset = []
+ for lvl in range(num_lvls):
+ c_offset_x, c_offset_y = _ctr_offset(anchor_list[i][lvl],
+ anchor_strides[lvl],
+ featmap_sizes[lvl])
+ s_offset_x, s_offset_y = _shape_offset(anchor_list[i][lvl],
+ anchor_strides[lvl])
+
+ # offset = ctr_offset + shape_offset
+ offset_x = s_offset_x + c_offset_x[:, None]
+ offset_y = s_offset_y + c_offset_y[:, None]
+
+                # offset order (y0, x0, y1, x1, ..., y8, x8) for the 3x3 kernel
+ offset = torch.stack([offset_y, offset_x], dim=-1)
+ offset = offset.reshape(offset.size(0), -1) # [NA, 2*ks**2]
+ mlvl_offset.append(offset)
+ offset_list.append(torch.cat(mlvl_offset)) # [totalNA, 2*ks**2]
+ offset_list = images_to_levels(offset_list, num_level_anchors)
+ return offset_list
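+
+        # Illustrative arithmetic (added note, not part of the original code):
+        # for a hypothetical anchor (x1, y1, x2, y2) = (8, 8, 40, 40) on a
+        # stride-8 level with ks=3 and dilation=1:
+        #   shape term:  w = (40 - 8) / 8 / (3 - 1) - 1 = 1.0, so each
+        #                off-centre kernel tap is pushed one extra feature-map
+        #                cell outwards in x (and likewise h for y);
+        #   centre term: the (usually fractional) distance between the anchor
+        #                centre projected onto the feature map and the integer
+        #                grid position the anchor is assigned to.
+        # The final DeformConv offset per tap is the sum of both terms.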
+
+ def loss_single(self, cls_score, bbox_pred, anchors, labels, label_weights,
+ bbox_targets, bbox_weights, num_total_samples):
+ """Loss function on single scale."""
+ # classification loss
+ if self.with_cls:
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # regression loss
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ bbox_weights = bbox_weights.reshape(-1, 4)
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ anchors = anchors.reshape(-1, 4)
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+ loss_reg = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+ if self.with_cls:
+ return loss_cls, loss_reg
+ return None, loss_reg
+
+ def loss(self,
+ anchor_list,
+ valid_flag_list,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ anchor_list (list[list]): Multi level anchors of each image.
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in bbox_preds]
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ featmap_sizes,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ if self.sampling:
+ num_total_samples = num_total_pos + num_total_neg
+ else:
+ # 200 is hard-coded average factor,
+ # which follows guided anchoring.
+ num_total_samples = sum([label.numel()
+ for label in labels_list]) / 200.0
+
+ # change per image, per level anchor_list to per_level, per_image
+ mlvl_anchor_list = list(zip(*anchor_list))
+ # concat mlvl_anchor_list
+ mlvl_anchor_list = [
+ torch.cat(anchors, dim=0) for anchors in mlvl_anchor_list
+ ]
+
+ losses = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+ if self.with_cls:
+ return dict(loss_rpn_cls=losses[0], loss_rpn_reg=losses[1])
+ return dict(loss_rpn_reg=losses[1])
+
+ def get_bboxes(self,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg,
+ rescale=False):
+ """Get proposal predict."""
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ anchor_list[img_id], img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def refine_bboxes(self, anchor_list, bbox_preds, img_metas):
+ """Refine bboxes through stages."""
+ num_levels = len(bbox_preds)
+ new_anchor_list = []
+ for img_id in range(len(img_metas)):
+ mlvl_anchors = []
+ for i in range(num_levels):
+ bbox_pred = bbox_preds[i][img_id].detach()
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ img_shape = img_metas[img_id]['img_shape']
+ bboxes = self.bbox_coder.decode(anchor_list[img_id][i],
+ bbox_pred, img_shape)
+ mlvl_anchors.append(bboxes)
+ new_anchor_list.append(mlvl_anchors)
+ return new_anchor_list
+
+ # TODO: temporary plan
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (num_anchors * 4, H, W).
+ mlvl_anchors (list[Tensor]): Box reference for each scale level
+ with shape (num_total_anchors, 4).
+ img_shape (tuple[int]): Shape of the input image,
+ (height, width, 3).
+            scale_factor (ndarray): Scale factor of the image, arranged as
+                (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+
+ Returns:
+ Tensor: Labeled boxes have the shape of (n,5), where the
+ first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+ between 0 and 1.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ cfg = copy.deepcopy(cfg)
+ # bboxes from different level should be independent during NMS,
+ # level_ids are used as labels for batched NMS to separate them
+ level_ids = []
+ mlvl_scores = []
+ mlvl_bbox_preds = []
+ mlvl_valid_anchors = []
+ for idx in range(len(cls_scores)):
+ rpn_cls_score = cls_scores[idx]
+ rpn_bbox_pred = bbox_preds[idx]
+ assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+ rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
+ if self.use_sigmoid_cls:
+ rpn_cls_score = rpn_cls_score.reshape(-1)
+ scores = rpn_cls_score.sigmoid()
+ else:
+ rpn_cls_score = rpn_cls_score.reshape(-1, 2)
+ # We set FG labels to [0, num_class-1] and BG label to
+ # num_class in RPN head since mmdet v2.5, which is unified to
+ # be consistent with other head since mmdet v2.0. In mmdet v2.0
+ # to v2.4 we keep BG label as 0 and FG label as 1 in rpn head.
+ scores = rpn_cls_score.softmax(dim=1)[:, 0]
+ rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ anchors = mlvl_anchors[idx]
+ if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
+ # sort is faster than topk
+ # _, topk_inds = scores.topk(cfg.nms_pre)
+ if torch.onnx.is_in_onnx_export():
+ # sort op will be converted to TopK in onnx
+ # and k<=3480 in TensorRT
+ _, topk_inds = scores.topk(cfg.nms_pre)
+ scores = scores[topk_inds]
+ else:
+ ranked_scores, rank_inds = scores.sort(descending=True)
+ topk_inds = rank_inds[:cfg.nms_pre]
+ scores = ranked_scores[:cfg.nms_pre]
+ rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+ mlvl_scores.append(scores)
+ mlvl_bbox_preds.append(rpn_bbox_pred)
+ mlvl_valid_anchors.append(anchors)
+ level_ids.append(
+ scores.new_full((scores.size(0), ), idx, dtype=torch.long))
+
+ scores = torch.cat(mlvl_scores)
+ anchors = torch.cat(mlvl_valid_anchors)
+ rpn_bbox_pred = torch.cat(mlvl_bbox_preds)
+ proposals = self.bbox_coder.decode(
+ anchors, rpn_bbox_pred, max_shape=img_shape)
+ ids = torch.cat(level_ids)
+
+ # Skip nonzero op while exporting to ONNX
+ if cfg.min_bbox_size > 0 and (not torch.onnx.is_in_onnx_export()):
+ w = proposals[:, 2] - proposals[:, 0]
+ h = proposals[:, 3] - proposals[:, 1]
+ valid_inds = torch.nonzero(
+ (w >= cfg.min_bbox_size)
+ & (h >= cfg.min_bbox_size),
+ as_tuple=False).squeeze()
+ if valid_inds.sum().item() != len(proposals):
+ proposals = proposals[valid_inds, :]
+ scores = scores[valid_inds]
+ ids = ids[valid_inds]
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+ warnings.warn(
+                'In rpn_proposal or test_cfg, '
+                'nms_thr has been moved into a dict named nms as '
+                'iou_threshold and max_num has been renamed max_per_img; '
+                'the original argument names and the old way of specifying '
+                'the NMS iou_threshold are deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+ assert cfg.max_num == cfg.max_per_img, f'You ' \
+ f'set max_num and ' \
+ f'max_per_img at the same time, but get {cfg.max_num} ' \
+                    f'and {cfg.max_per_img} respectively. ' \
+                    'Please delete max_num, which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \
+ f' iou_threshold in nms and ' \
+ f'nms_thr at the same time, but get' \
+ f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete the nms_thr ' \
+ f'which will be deprecated.'
+
+ dets, keep = batched_nms(proposals, scores, ids, cfg.nms)
+ return dets[:cfg.max_per_img]
+
+
+@HEADS.register_module()
+class CascadeRPNHead(BaseDenseHead):
+ """The CascadeRPNHead will predict more accurate region proposals, which is
+ required for two-stage detectors (such as Fast/Faster R-CNN). CascadeRPN
+ consists of a sequence of RPNStage to progressively improve the accuracy of
+ the detected proposals.
+
+ More details can be found in ``https://arxiv.org/abs/1909.06720``.
+
+ Args:
+ num_stages (int): number of CascadeRPN stages.
+ stages (list[dict]): list of configs to build the stages.
+        train_cfg (list[dict]): list of training configs, one for each stage.
+ test_cfg (dict): config at testing time.
+ """
+
+ def __init__(self, num_stages, stages, train_cfg, test_cfg):
+ super(CascadeRPNHead, self).__init__()
+ assert num_stages == len(stages)
+ self.num_stages = num_stages
+ self.stages = nn.ModuleList()
+ for i in range(len(stages)):
+ train_cfg_i = train_cfg[i] if train_cfg is not None else None
+ stages[i].update(train_cfg=train_cfg_i)
+ stages[i].update(test_cfg=test_cfg)
+ self.stages.append(build_head(stages[i]))
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ def init_weights(self):
+ """Init weight of CascadeRPN."""
+ for i in range(self.num_stages):
+ self.stages[i].init_weights()
+
+ def loss(self):
+ """loss() is implemented in StageCascadeRPNHead."""
+ pass
+
+ def get_bboxes(self):
+ """get_bboxes() is implemented in StageCascadeRPNHead."""
+ pass
+
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None):
+ """Forward train function."""
+ assert gt_labels is None, 'RPN does not require gt_labels'
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in x]
+ device = x[0].device
+ anchor_list, valid_flag_list = self.stages[0].get_anchors(
+ featmap_sizes, img_metas, device=device)
+
+ losses = dict()
+
+ for i in range(self.num_stages):
+ stage = self.stages[i]
+
+ if stage.adapt_cfg['type'] == 'offset':
+ offset_list = stage.anchor_offset(anchor_list,
+ stage.anchor_strides,
+ featmap_sizes)
+ else:
+ offset_list = None
+ x, cls_score, bbox_pred = stage(x, offset_list)
+ rpn_loss_inputs = (anchor_list, valid_flag_list, cls_score,
+ bbox_pred, gt_bboxes, img_metas)
+ stage_loss = stage.loss(*rpn_loss_inputs)
+ for name, value in stage_loss.items():
+ losses['s{}.{}'.format(i, name)] = value
+
+ # refine boxes
+ if i < self.num_stages - 1:
+ anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
+ img_metas)
+ if proposal_cfg is None:
+ return losses
+ else:
+ proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
+ bbox_pred, img_metas,
+ self.test_cfg)
+ return losses, proposal_list
+
+ def simple_test_rpn(self, x, img_metas):
+ """Simple forward test function."""
+ featmap_sizes = [featmap.size()[-2:] for featmap in x]
+ device = x[0].device
+ anchor_list, _ = self.stages[0].get_anchors(
+ featmap_sizes, img_metas, device=device)
+
+ for i in range(self.num_stages):
+ stage = self.stages[i]
+ if stage.adapt_cfg['type'] == 'offset':
+ offset_list = stage.anchor_offset(anchor_list,
+ stage.anchor_strides,
+ featmap_sizes)
+ else:
+ offset_list = None
+ x, cls_score, bbox_pred = stage(x, offset_list)
+ if i < self.num_stages - 1:
+ anchor_list = stage.refine_bboxes(anchor_list, bbox_pred,
+ img_metas)
+
+ proposal_list = self.stages[-1].get_bboxes(anchor_list, cls_score,
+ bbox_pred, img_metas,
+ self.test_cfg)
+ return proposal_list
+
+ def aug_test_rpn(self, x, img_metas):
+ """Augmented forward test function."""
+ raise NotImplementedError
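+
+
+# Illustrative usage (added note; not part of the original file). A
+# hypothetical two-stage config sketch in the style of mmdet configs; the
+# channel numbers and flags below are assumptions for illustration only, see
+# the configs/ directory for the values this repo actually uses.
+#
+# rpn_head = dict(
+#     type='CascadeRPNHead',
+#     num_stages=2,
+#     stages=[
+#         dict(type='StageCascadeRPNHead',
+#              in_channels=256, feat_channels=256,
+#              adapt_cfg=dict(type='dilation', dilation=3),
+#              bridged_feature=True, with_cls=False, sampling=False),
+#         dict(type='StageCascadeRPNHead',
+#              in_channels=256, feat_channels=256,
+#              adapt_cfg=dict(type='offset'),
+#              bridged_feature=False, with_cls=True, sampling=True),
+#     ],
+#     train_cfg=[...],  # one train_cfg dict per stage
+#     test_cfg=dict(...))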
diff --git a/mmdet/models/dense_heads/centripetal_head.py b/mmdet/models/dense_heads/centripetal_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6728218b60539a71f6353645635f741a1ad7263d
--- /dev/null
+++ b/mmdet/models/dense_heads/centripetal_head.py
@@ -0,0 +1,421 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, normal_init
+from mmcv.ops import DeformConv2d
+
+from mmdet.core import multi_apply
+from ..builder import HEADS, build_loss
+from .corner_head import CornerHead
+
+
+@HEADS.register_module()
+class CentripetalHead(CornerHead):
+ """Head of CentripetalNet: Pursuing High-quality Keypoint Pairs for Object
+ Detection.
+
+ CentripetalHead inherits from :class:`CornerHead`. It removes the
+ embedding branch and adds guiding shift and centripetal shift branches.
+ More details can be found in the `paper
+    <https://arxiv.org/abs/2003.09119>`_ .
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+        num_feat_levels (int): Levels of feature from the previous module: 2
+            for HourglassNet-104 and 1 for HourglassNet-52, because
+            HourglassNet-104 outputs both the final feature and an
+            intermediate supervision feature while HourglassNet-52 only
+            outputs the final feature. Default: 2.
+ corner_emb_channels (int): Channel of embedding vector. Default: 1.
+ train_cfg (dict | None): Training config. Useless in CornerHead,
+ but we keep this variable for SingleStageDetector. Default: None.
+ test_cfg (dict | None): Testing config of CornerHead. Default: None.
+ loss_heatmap (dict | None): Config of corner heatmap loss. Default:
+ GaussianFocalLoss.
+ loss_embedding (dict | None): Config of corner embedding loss. Default:
+ AssociativeEmbeddingLoss.
+ loss_offset (dict | None): Config of corner offset loss. Default:
+ SmoothL1Loss.
+ loss_guiding_shift (dict): Config of guiding shift loss. Default:
+ SmoothL1Loss.
+ loss_centripetal_shift (dict): Config of centripetal shift loss.
+ Default: SmoothL1Loss.
+ """
+
+ def __init__(self,
+ *args,
+ centripetal_shift_channels=2,
+ guiding_shift_channels=2,
+ feat_adaption_conv_kernel=3,
+ loss_guiding_shift=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
+ loss_centripetal_shift=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1),
+ **kwargs):
+ assert centripetal_shift_channels == 2, (
+            'CentripetalHead only supports centripetal_shift_channels == 2')
+ self.centripetal_shift_channels = centripetal_shift_channels
+ assert guiding_shift_channels == 2, (
+            'CentripetalHead only supports guiding_shift_channels == 2')
+ self.guiding_shift_channels = guiding_shift_channels
+ self.feat_adaption_conv_kernel = feat_adaption_conv_kernel
+ super(CentripetalHead, self).__init__(*args, **kwargs)
+ self.loss_guiding_shift = build_loss(loss_guiding_shift)
+ self.loss_centripetal_shift = build_loss(loss_centripetal_shift)
+
+ def _init_centripetal_layers(self):
+ """Initialize centripetal layers.
+
+ Including feature adaption deform convs (feat_adaption), deform offset
+        prediction convs (dcn_offset), guiding shift (guiding_shift) and
+        centripetal shift (centripetal_shift). Each branch has two parts:
+ prefix `tl_` for top-left and `br_` for bottom-right.
+ """
+ self.tl_feat_adaption = nn.ModuleList()
+ self.br_feat_adaption = nn.ModuleList()
+ self.tl_dcn_offset = nn.ModuleList()
+ self.br_dcn_offset = nn.ModuleList()
+ self.tl_guiding_shift = nn.ModuleList()
+ self.br_guiding_shift = nn.ModuleList()
+ self.tl_centripetal_shift = nn.ModuleList()
+ self.br_centripetal_shift = nn.ModuleList()
+
+ for _ in range(self.num_feat_levels):
+ self.tl_feat_adaption.append(
+ DeformConv2d(self.in_channels, self.in_channels,
+ self.feat_adaption_conv_kernel, 1, 1))
+ self.br_feat_adaption.append(
+ DeformConv2d(self.in_channels, self.in_channels,
+ self.feat_adaption_conv_kernel, 1, 1))
+
+ self.tl_guiding_shift.append(
+ self._make_layers(
+ out_channels=self.guiding_shift_channels,
+ in_channels=self.in_channels))
+ self.br_guiding_shift.append(
+ self._make_layers(
+ out_channels=self.guiding_shift_channels,
+ in_channels=self.in_channels))
+
+ self.tl_dcn_offset.append(
+ ConvModule(
+ self.guiding_shift_channels,
+ self.feat_adaption_conv_kernel**2 *
+ self.guiding_shift_channels,
+ 1,
+ bias=False,
+ act_cfg=None))
+ self.br_dcn_offset.append(
+ ConvModule(
+ self.guiding_shift_channels,
+ self.feat_adaption_conv_kernel**2 *
+ self.guiding_shift_channels,
+ 1,
+ bias=False,
+ act_cfg=None))
+
+ self.tl_centripetal_shift.append(
+ self._make_layers(
+ out_channels=self.centripetal_shift_channels,
+ in_channels=self.in_channels))
+ self.br_centripetal_shift.append(
+ self._make_layers(
+ out_channels=self.centripetal_shift_channels,
+ in_channels=self.in_channels))
+
+ def _init_layers(self):
+ """Initialize layers for CentripetalHead.
+
+ Including two parts: CornerHead layers and CentripetalHead layers
+ """
+ super()._init_layers() # using _init_layers in CornerHead
+ self._init_centripetal_layers()
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ super().init_weights()
+ for i in range(self.num_feat_levels):
+ normal_init(self.tl_feat_adaption[i], std=0.01)
+ normal_init(self.br_feat_adaption[i], std=0.01)
+ normal_init(self.tl_dcn_offset[i].conv, std=0.1)
+ normal_init(self.br_dcn_offset[i].conv, std=0.1)
+ _ = [x.conv.reset_parameters() for x in self.tl_guiding_shift[i]]
+ _ = [x.conv.reset_parameters() for x in self.br_guiding_shift[i]]
+ _ = [
+ x.conv.reset_parameters() for x in self.tl_centripetal_shift[i]
+ ]
+ _ = [
+ x.conv.reset_parameters() for x in self.br_centripetal_shift[i]
+ ]
+
+ def forward_single(self, x, lvl_ind):
+ """Forward feature of a single level.
+
+ Args:
+ x (Tensor): Feature of a single level.
+ lvl_ind (int): Level index of current feature.
+
+ Returns:
+ tuple[Tensor]: A tuple of CentripetalHead's output for current
+ feature level. Containing the following Tensors:
+
+ - tl_heat (Tensor): Predicted top-left corner heatmap.
+ - br_heat (Tensor): Predicted bottom-right corner heatmap.
+ - tl_off (Tensor): Predicted top-left offset heatmap.
+ - br_off (Tensor): Predicted bottom-right offset heatmap.
+ - tl_guiding_shift (Tensor): Predicted top-left guiding shift
+ heatmap.
+ - br_guiding_shift (Tensor): Predicted bottom-right guiding
+ shift heatmap.
+ - tl_centripetal_shift (Tensor): Predicted top-left centripetal
+ shift heatmap.
+ - br_centripetal_shift (Tensor): Predicted bottom-right
+ centripetal shift heatmap.
+ """
+ tl_heat, br_heat, _, _, tl_off, br_off, tl_pool, br_pool = super(
+ ).forward_single(
+ x, lvl_ind, return_pool=True)
+
+ tl_guiding_shift = self.tl_guiding_shift[lvl_ind](tl_pool)
+ br_guiding_shift = self.br_guiding_shift[lvl_ind](br_pool)
+
+ tl_dcn_offset = self.tl_dcn_offset[lvl_ind](tl_guiding_shift.detach())
+ br_dcn_offset = self.br_dcn_offset[lvl_ind](br_guiding_shift.detach())
+
+ tl_feat_adaption = self.tl_feat_adaption[lvl_ind](tl_pool,
+ tl_dcn_offset)
+ br_feat_adaption = self.br_feat_adaption[lvl_ind](br_pool,
+ br_dcn_offset)
+
+ tl_centripetal_shift = self.tl_centripetal_shift[lvl_ind](
+ tl_feat_adaption)
+ br_centripetal_shift = self.br_centripetal_shift[lvl_ind](
+ br_feat_adaption)
+
+ result_list = [
+ tl_heat, br_heat, tl_off, br_off, tl_guiding_shift,
+ br_guiding_shift, tl_centripetal_shift, br_centripetal_shift
+ ]
+ return result_list
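+
+        # Added note (not part of the original code): within one level the
+        # data flow is  pooled feature -> guiding shift -> (detached) DCN
+        # offsets -> feature-adaption DeformConv -> centripetal shift, i.e.
+        # the guiding-shift branch both produces its own loss target and
+        # steers where the centripetal-shift branch samples its features.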
+
+ def loss(self,
+ tl_heats,
+ br_heats,
+ tl_offs,
+ br_offs,
+ tl_guiding_shifts,
+ br_guiding_shifts,
+ tl_centripetal_shifts,
+ br_centripetal_shifts,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
+ level with shape (N, guiding_shift_channels, H, W).
+ br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
+ each level with shape (N, guiding_shift_channels, H, W).
+ tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
+ for each level with shape (N, centripetal_shift_channels, H,
+ W).
+ br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
+ shifts for each level with shape (N,
+ centripetal_shift_channels, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [left, top, right, bottom] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components. Containing the
+ following losses:
+
+ - det_loss (list[Tensor]): Corner keypoint losses of all
+ feature levels.
+ - off_loss (list[Tensor]): Corner offset losses of all feature
+ levels.
+ - guiding_loss (list[Tensor]): Guiding shift losses of all
+ feature levels.
+ - centripetal_loss (list[Tensor]): Centripetal shift losses of
+ all feature levels.
+ """
+ targets = self.get_targets(
+ gt_bboxes,
+ gt_labels,
+ tl_heats[-1].shape,
+ img_metas[0]['pad_shape'],
+ with_corner_emb=self.with_corner_emb,
+ with_guiding_shift=True,
+ with_centripetal_shift=True)
+ mlvl_targets = [targets for _ in range(self.num_feat_levels)]
+ [det_losses, off_losses, guiding_losses, centripetal_losses
+ ] = multi_apply(self.loss_single, tl_heats, br_heats, tl_offs,
+ br_offs, tl_guiding_shifts, br_guiding_shifts,
+ tl_centripetal_shifts, br_centripetal_shifts,
+ mlvl_targets)
+ loss_dict = dict(
+ det_loss=det_losses,
+ off_loss=off_losses,
+ guiding_loss=guiding_losses,
+ centripetal_loss=centripetal_losses)
+ return loss_dict
+
+ def loss_single(self, tl_hmp, br_hmp, tl_off, br_off, tl_guiding_shift,
+ br_guiding_shift, tl_centripetal_shift,
+ br_centripetal_shift, targets):
+ """Compute losses for single level.
+
+ Args:
+ tl_hmp (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_hmp (Tensor): Bottom-right corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ tl_guiding_shift (Tensor): Top-left guiding shift for current level
+ with shape (N, guiding_shift_channels, H, W).
+ br_guiding_shift (Tensor): Bottom-right guiding shift for current
+ level with shape (N, guiding_shift_channels, H, W).
+ tl_centripetal_shift (Tensor): Top-left centripetal shift for
+ current level with shape (N, centripetal_shift_channels, H, W).
+ br_centripetal_shift (Tensor): Bottom-right centripetal shift for
+ current level with shape (N, centripetal_shift_channels, H, W).
+ targets (dict): Corner target generated by `get_targets`.
+
+ Returns:
+            tuple[torch.Tensor]: Losses of the head's different branches
+ containing the following losses:
+
+ - det_loss (Tensor): Corner keypoint loss.
+ - off_loss (Tensor): Corner offset loss.
+ - guiding_loss (Tensor): Guiding shift loss.
+ - centripetal_loss (Tensor): Centripetal shift loss.
+ """
+ targets['corner_embedding'] = None
+
+ det_loss, _, _, off_loss = super().loss_single(tl_hmp, br_hmp, None,
+ None, tl_off, br_off,
+ targets)
+
+ gt_tl_guiding_shift = targets['topleft_guiding_shift']
+ gt_br_guiding_shift = targets['bottomright_guiding_shift']
+ gt_tl_centripetal_shift = targets['topleft_centripetal_shift']
+ gt_br_centripetal_shift = targets['bottomright_centripetal_shift']
+
+ gt_tl_heatmap = targets['topleft_heatmap']
+ gt_br_heatmap = targets['bottomright_heatmap']
+ # We only compute the offset loss at the real corner position.
+ # The value of real corner would be 1 in heatmap ground truth.
+ # The mask is computed in class agnostic mode and its shape is
+ # batch * 1 * width * height.
+ tl_mask = gt_tl_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_tl_heatmap)
+ br_mask = gt_br_heatmap.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_br_heatmap)
+
+ # Guiding shift loss
+ tl_guiding_loss = self.loss_guiding_shift(
+ tl_guiding_shift,
+ gt_tl_guiding_shift,
+ tl_mask,
+ avg_factor=tl_mask.sum())
+ br_guiding_loss = self.loss_guiding_shift(
+ br_guiding_shift,
+ gt_br_guiding_shift,
+ br_mask,
+ avg_factor=br_mask.sum())
+ guiding_loss = (tl_guiding_loss + br_guiding_loss) / 2.0
+ # Centripetal shift loss
+ tl_centripetal_loss = self.loss_centripetal_shift(
+ tl_centripetal_shift,
+ gt_tl_centripetal_shift,
+ tl_mask,
+ avg_factor=tl_mask.sum())
+ br_centripetal_loss = self.loss_centripetal_shift(
+ br_centripetal_shift,
+ gt_br_centripetal_shift,
+ br_mask,
+ avg_factor=br_mask.sum())
+ centripetal_loss = (tl_centripetal_loss + br_centripetal_loss) / 2.0
+
+ return det_loss, off_loss, guiding_loss, centripetal_loss
+
+ def get_bboxes(self,
+ tl_heats,
+ br_heats,
+ tl_offs,
+ br_offs,
+ tl_guiding_shifts,
+ br_guiding_shifts,
+ tl_centripetal_shifts,
+ br_centripetal_shifts,
+ img_metas,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ tl_guiding_shifts (list[Tensor]): Top-left guiding shifts for each
+ level with shape (N, guiding_shift_channels, H, W). Useless in
+ this function, we keep this arg because it's the raw output
+ from CentripetalHead.
+ br_guiding_shifts (list[Tensor]): Bottom-right guiding shifts for
+ each level with shape (N, guiding_shift_channels, H, W).
+ Useless in this function, we keep this arg because it's the
+ raw output from CentripetalHead.
+ tl_centripetal_shifts (list[Tensor]): Top-left centripetal shifts
+ for each level with shape (N, centripetal_shift_channels, H,
+ W).
+ br_centripetal_shifts (list[Tensor]): Bottom-right centripetal
+ shifts for each level with shape (N,
+ centripetal_shift_channels, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+ """
+ assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ result_list.append(
+ self._get_bboxes_single(
+ tl_heats[-1][img_id:img_id + 1, :],
+ br_heats[-1][img_id:img_id + 1, :],
+ tl_offs[-1][img_id:img_id + 1, :],
+ br_offs[-1][img_id:img_id + 1, :],
+ img_metas[img_id],
+ tl_emb=None,
+ br_emb=None,
+ tl_centripetal_shift=tl_centripetal_shifts[-1][
+ img_id:img_id + 1, :],
+ br_centripetal_shift=br_centripetal_shifts[-1][
+ img_id:img_id + 1, :],
+ rescale=rescale,
+ with_nms=with_nms))
+
+ return result_list
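+
+
+# Illustrative usage (added note; not part of the original file). A
+# hypothetical bbox_head sketch in the style of mmdet CentripetalNet configs;
+# the numbers here are assumptions for illustration, not values taken from
+# this repo.
+#
+# bbox_head = dict(
+#     type='CentripetalHead',
+#     num_classes=80,
+#     in_channels=256,
+#     num_feat_levels=2,          # e.g. an HourglassNet-104 backbone
+#     corner_emb_channels=0,      # embeddings replaced by centripetal shifts
+#     loss_guiding_shift=dict(type='SmoothL1Loss', beta=1.0, loss_weight=0.05),
+#     loss_centripetal_shift=dict(type='SmoothL1Loss', beta=1.0, loss_weight=1))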
diff --git a/mmdet/models/dense_heads/corner_head.py b/mmdet/models/dense_heads/corner_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..50cdb49a29f2ced1a31a50e654a3bdc14f5f5004
--- /dev/null
+++ b/mmdet/models/dense_heads/corner_head.py
@@ -0,0 +1,1074 @@
+from logging import warning
+from math import ceil, log
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, bias_init_with_prob
+from mmcv.ops import CornerPool, batched_nms
+
+from mmdet.core import multi_apply
+from ..builder import HEADS, build_loss
+from ..utils import gaussian_radius, gen_gaussian_target
+from .base_dense_head import BaseDenseHead
+
+
+class BiCornerPool(nn.Module):
+ """Bidirectional Corner Pooling Module (TopLeft, BottomRight, etc.)
+
+ Args:
+ in_channels (int): Input channels of module.
+ out_channels (int): Output channels of module.
+ feat_channels (int): Feature channels of module.
+ directions (list[str]): Directions of two CornerPools.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ """
+
+ def __init__(self,
+ in_channels,
+ directions,
+ feat_channels=128,
+ out_channels=128,
+ norm_cfg=dict(type='BN', requires_grad=True)):
+ super(BiCornerPool, self).__init__()
+ self.direction1_conv = ConvModule(
+ in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
+ self.direction2_conv = ConvModule(
+ in_channels, feat_channels, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.aftpool_conv = ConvModule(
+ feat_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ self.conv1 = ConvModule(
+ in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+ self.conv2 = ConvModule(
+ in_channels, out_channels, 3, padding=1, norm_cfg=norm_cfg)
+
+ self.direction1_pool = CornerPool(directions[0])
+ self.direction2_pool = CornerPool(directions[1])
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ """Forward features from the upstream network.
+
+ Args:
+ x (tensor): Input feature of BiCornerPool.
+
+ Returns:
+ conv2 (tensor): Output feature of BiCornerPool.
+ """
+ direction1_conv = self.direction1_conv(x)
+ direction2_conv = self.direction2_conv(x)
+ direction1_feat = self.direction1_pool(direction1_conv)
+ direction2_feat = self.direction2_pool(direction2_conv)
+ aftpool_conv = self.aftpool_conv(direction1_feat + direction2_feat)
+ conv1 = self.conv1(x)
+ relu = self.relu(aftpool_conv + conv1)
+ conv2 = self.conv2(relu)
+ return conv2
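+
+        # Added note (not part of the original code): each CornerPool
+        # direction is a running max along one axis. For example, assuming
+        # 'top' pooling replaces every value with the max of itself and
+        # everything below it in the same column, a column [1, 3, 2, 0]
+        # becomes [3, 3, 2, 0]. The two pooled directions are summed before
+        # aftpool_conv.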
+
+
+@HEADS.register_module()
+class CornerHead(BaseDenseHead):
+ """Head of CornerNet: Detecting Objects as Paired Keypoints.
+
+ Code is modified from the `official github repo
+    <https://github.com/princeton-vl/CornerNet>`_ .
+
+ More details can be found in the `paper
+    <https://arxiv.org/abs/1808.01244>`_ .
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+        num_feat_levels (int): Levels of feature from the previous module: 2
+            for HourglassNet-104 and 1 for HourglassNet-52, because
+            HourglassNet-104 outputs both the final feature and an
+            intermediate supervision feature while HourglassNet-52 only
+            outputs the final feature. Default: 2.
+ corner_emb_channels (int): Channel of embedding vector. Default: 1.
+ train_cfg (dict | None): Training config. Useless in CornerHead,
+ but we keep this variable for SingleStageDetector. Default: None.
+ test_cfg (dict | None): Testing config of CornerHead. Default: None.
+ loss_heatmap (dict | None): Config of corner heatmap loss. Default:
+ GaussianFocalLoss.
+ loss_embedding (dict | None): Config of corner embedding loss. Default:
+ AssociativeEmbeddingLoss.
+ loss_offset (dict | None): Config of corner offset loss. Default:
+ SmoothL1Loss.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ num_feat_levels=2,
+ corner_emb_channels=1,
+ train_cfg=None,
+ test_cfg=None,
+ loss_heatmap=dict(
+ type='GaussianFocalLoss',
+ alpha=2.0,
+ gamma=4.0,
+ loss_weight=1),
+ loss_embedding=dict(
+ type='AssociativeEmbeddingLoss',
+ pull_weight=0.25,
+ push_weight=0.25),
+ loss_offset=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1)):
+ super(CornerHead, self).__init__()
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.corner_emb_channels = corner_emb_channels
+ self.with_corner_emb = self.corner_emb_channels > 0
+ self.corner_offset_channels = 2
+ self.num_feat_levels = num_feat_levels
+ self.loss_heatmap = build_loss(
+ loss_heatmap) if loss_heatmap is not None else None
+ self.loss_embedding = build_loss(
+ loss_embedding) if loss_embedding is not None else None
+ self.loss_offset = build_loss(
+ loss_offset) if loss_offset is not None else None
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self._init_layers()
+
+ def _make_layers(self, out_channels, in_channels=256, feat_channels=256):
+ """Initialize conv sequential for CornerHead."""
+ return nn.Sequential(
+ ConvModule(in_channels, feat_channels, 3, padding=1),
+ ConvModule(
+ feat_channels, out_channels, 1, norm_cfg=None, act_cfg=None))
+
+ def _init_corner_kpt_layers(self):
+ """Initialize corner keypoint layers.
+
+ Including corner heatmap branch and corner offset branch. Each branch
+ has two parts: prefix `tl_` for top-left and `br_` for bottom-right.
+ """
+ self.tl_pool, self.br_pool = nn.ModuleList(), nn.ModuleList()
+ self.tl_heat, self.br_heat = nn.ModuleList(), nn.ModuleList()
+ self.tl_off, self.br_off = nn.ModuleList(), nn.ModuleList()
+
+ for _ in range(self.num_feat_levels):
+ self.tl_pool.append(
+ BiCornerPool(
+ self.in_channels, ['top', 'left'],
+ out_channels=self.in_channels))
+ self.br_pool.append(
+ BiCornerPool(
+ self.in_channels, ['bottom', 'right'],
+ out_channels=self.in_channels))
+
+ self.tl_heat.append(
+ self._make_layers(
+ out_channels=self.num_classes,
+ in_channels=self.in_channels))
+ self.br_heat.append(
+ self._make_layers(
+ out_channels=self.num_classes,
+ in_channels=self.in_channels))
+
+ self.tl_off.append(
+ self._make_layers(
+ out_channels=self.corner_offset_channels,
+ in_channels=self.in_channels))
+ self.br_off.append(
+ self._make_layers(
+ out_channels=self.corner_offset_channels,
+ in_channels=self.in_channels))
+
+ def _init_corner_emb_layers(self):
+ """Initialize corner embedding layers.
+
+ Only include corner embedding branch with two parts: prefix `tl_` for
+ top-left and `br_` for bottom-right.
+ """
+ self.tl_emb, self.br_emb = nn.ModuleList(), nn.ModuleList()
+
+ for _ in range(self.num_feat_levels):
+ self.tl_emb.append(
+ self._make_layers(
+ out_channels=self.corner_emb_channels,
+ in_channels=self.in_channels))
+ self.br_emb.append(
+ self._make_layers(
+ out_channels=self.corner_emb_channels,
+ in_channels=self.in_channels))
+
+ def _init_layers(self):
+ """Initialize layers for CornerHead.
+
+ Including two parts: corner keypoint layers and corner embedding layers
+ """
+ self._init_corner_kpt_layers()
+ if self.with_corner_emb:
+ self._init_corner_emb_layers()
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ bias_init = bias_init_with_prob(0.1)
+ for i in range(self.num_feat_levels):
+            # The initialization of parameters is different between nn.Conv2d
+ # and ConvModule. Our experiments show that using the original
+ # initialization of nn.Conv2d increases the final mAP by about 0.2%
+ self.tl_heat[i][-1].conv.reset_parameters()
+ self.tl_heat[i][-1].conv.bias.data.fill_(bias_init)
+ self.br_heat[i][-1].conv.reset_parameters()
+ self.br_heat[i][-1].conv.bias.data.fill_(bias_init)
+ self.tl_off[i][-1].conv.reset_parameters()
+ self.br_off[i][-1].conv.reset_parameters()
+ if self.with_corner_emb:
+ self.tl_emb[i][-1].conv.reset_parameters()
+ self.br_emb[i][-1].conv.reset_parameters()
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of corner heatmaps, offset heatmaps and
+ embedding heatmaps.
+ - tl_heats (list[Tensor]): Top-left corner heatmaps for all
+ levels, each is a 4D-tensor, the channels number is
+ num_classes.
+ - br_heats (list[Tensor]): Bottom-right corner heatmaps for all
+ levels, each is a 4D-tensor, the channels number is
+ num_classes.
+ - tl_embs (list[Tensor] | list[None]): Top-left embedding
+ heatmaps for all levels, each is a 4D-tensor or None.
+ If not None, the channels number is corner_emb_channels.
+ - br_embs (list[Tensor] | list[None]): Bottom-right embedding
+ heatmaps for all levels, each is a 4D-tensor or None.
+ If not None, the channels number is corner_emb_channels.
+ - tl_offs (list[Tensor]): Top-left offset heatmaps for all
+ levels, each is a 4D-tensor. The channels number is
+ corner_offset_channels.
+ - br_offs (list[Tensor]): Bottom-right offset heatmaps for all
+ levels, each is a 4D-tensor. The channels number is
+ corner_offset_channels.
+ """
+ lvl_ind = list(range(self.num_feat_levels))
+ return multi_apply(self.forward_single, feats, lvl_ind)
+
+ def forward_single(self, x, lvl_ind, return_pool=False):
+ """Forward feature of a single level.
+
+ Args:
+ x (Tensor): Feature of a single level.
+ lvl_ind (int): Level index of current feature.
+ return_pool (bool): Return corner pool feature or not.
+
+ Returns:
+ tuple[Tensor]: A tuple of CornerHead's output for current feature
+ level. Containing the following Tensors:
+
+ - tl_heat (Tensor): Predicted top-left corner heatmap.
+ - br_heat (Tensor): Predicted bottom-right corner heatmap.
+ - tl_emb (Tensor | None): Predicted top-left embedding heatmap.
+ None for `self.with_corner_emb == False`.
+ - br_emb (Tensor | None): Predicted bottom-right embedding
+ heatmap. None for `self.with_corner_emb == False`.
+ - tl_off (Tensor): Predicted top-left offset heatmap.
+ - br_off (Tensor): Predicted bottom-right offset heatmap.
+                - tl_pool (Tensor): Top-left corner pool feature. Only
+                  returned when ``return_pool`` is True.
+                - br_pool (Tensor): Bottom-right corner pool feature. Only
+                  returned when ``return_pool`` is True.
+ """
+ tl_pool = self.tl_pool[lvl_ind](x)
+ tl_heat = self.tl_heat[lvl_ind](tl_pool)
+ br_pool = self.br_pool[lvl_ind](x)
+ br_heat = self.br_heat[lvl_ind](br_pool)
+
+ tl_emb, br_emb = None, None
+ if self.with_corner_emb:
+ tl_emb = self.tl_emb[lvl_ind](tl_pool)
+ br_emb = self.br_emb[lvl_ind](br_pool)
+
+ tl_off = self.tl_off[lvl_ind](tl_pool)
+ br_off = self.br_off[lvl_ind](br_pool)
+
+ result_list = [tl_heat, br_heat, tl_emb, br_emb, tl_off, br_off]
+ if return_pool:
+ result_list.append(tl_pool)
+ result_list.append(br_pool)
+
+ return result_list
+
+ def get_targets(self,
+ gt_bboxes,
+ gt_labels,
+ feat_shape,
+ img_shape,
+ with_corner_emb=False,
+ with_guiding_shift=False,
+ with_centripetal_shift=False):
+ """Generate corner targets.
+
+ Including corner heatmap, corner offset.
+
+ Optional: corner embedding, corner guiding shift, centripetal shift.
+
+ For CornerNet, we generate corner heatmap, corner offset and corner
+ embedding from this function.
+
+ For CentripetalNet, we generate corner heatmap, corner offset, guiding
+ shift and centripetal shift from this function.
+
+ Args:
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image, each
+ has shape (num_gt, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each box, each has
+ shape (num_gt,).
+ feat_shape (list[int]): Shape of output feature,
+ [batch, channel, height, width].
+ img_shape (list[int]): Shape of input image,
+ [height, width, channel].
+ with_corner_emb (bool): Generate corner embedding target or not.
+ Default: False.
+ with_guiding_shift (bool): Generate guiding shift target or not.
+ Default: False.
+ with_centripetal_shift (bool): Generate centripetal shift target or
+ not. Default: False.
+
+ Returns:
+ dict: Ground truth of corner heatmap, corner offset, corner
+ embedding, guiding shift and centripetal shift. Containing the
+ following keys:
+
+ - topleft_heatmap (Tensor): Ground truth top-left corner
+ heatmap.
+ - bottomright_heatmap (Tensor): Ground truth bottom-right
+ corner heatmap.
+ - topleft_offset (Tensor): Ground truth top-left corner offset.
+ - bottomright_offset (Tensor): Ground truth bottom-right corner
+ offset.
+                - corner_embedding (list[list[list[int]]]): Ground truth corner
+                  embedding. Optional.
+                - topleft_guiding_shift (Tensor): Ground truth top-left corner
+                  guiding shift. Optional.
+                - bottomright_guiding_shift (Tensor): Ground truth bottom-right
+                  corner guiding shift. Optional.
+                - topleft_centripetal_shift (Tensor): Ground truth top-left
+                  corner centripetal shift. Optional.
+                - bottomright_centripetal_shift (Tensor): Ground truth
+                  bottom-right corner centripetal shift. Optional.
+ """
+ batch_size, _, height, width = feat_shape
+ img_h, img_w = img_shape[:2]
+
+ width_ratio = float(width / img_w)
+ height_ratio = float(height / img_h)
+
+ gt_tl_heatmap = gt_bboxes[-1].new_zeros(
+ [batch_size, self.num_classes, height, width])
+ gt_br_heatmap = gt_bboxes[-1].new_zeros(
+ [batch_size, self.num_classes, height, width])
+ gt_tl_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
+ gt_br_offset = gt_bboxes[-1].new_zeros([batch_size, 2, height, width])
+
+ if with_corner_emb:
+ match = []
+
+ # Guiding shift is a kind of offset, from center to corner
+ if with_guiding_shift:
+ gt_tl_guiding_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+ gt_br_guiding_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+ # Centripetal shift is also a kind of offset, from center to corner
+ # and normalized by log.
+ if with_centripetal_shift:
+ gt_tl_centripetal_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+ gt_br_centripetal_shift = gt_bboxes[-1].new_zeros(
+ [batch_size, 2, height, width])
+
+ for batch_id in range(batch_size):
+ # Ground truth of corner embedding per image is a list of coord set
+ corner_match = []
+ for box_id in range(len(gt_labels[batch_id])):
+ left, top, right, bottom = gt_bboxes[batch_id][box_id]
+ center_x = (left + right) / 2.0
+ center_y = (top + bottom) / 2.0
+ label = gt_labels[batch_id][box_id]
+
+ # Use coords in the feature level to generate ground truth
+ scale_left = left * width_ratio
+ scale_right = right * width_ratio
+ scale_top = top * height_ratio
+ scale_bottom = bottom * height_ratio
+ scale_center_x = center_x * width_ratio
+ scale_center_y = center_y * height_ratio
+
+ # Int coords on feature map/ground truth tensor
+ left_idx = int(min(scale_left, width - 1))
+ right_idx = int(min(scale_right, width - 1))
+ top_idx = int(min(scale_top, height - 1))
+ bottom_idx = int(min(scale_bottom, height - 1))
+
+ # Generate gaussian heatmap
+ scale_box_width = ceil(scale_right - scale_left)
+ scale_box_height = ceil(scale_bottom - scale_top)
+ radius = gaussian_radius((scale_box_height, scale_box_width),
+ min_overlap=0.3)
+ radius = max(0, int(radius))
+ gt_tl_heatmap[batch_id, label] = gen_gaussian_target(
+ gt_tl_heatmap[batch_id, label], [left_idx, top_idx],
+ radius)
+ gt_br_heatmap[batch_id, label] = gen_gaussian_target(
+ gt_br_heatmap[batch_id, label], [right_idx, bottom_idx],
+ radius)
+
+ # Generate corner offset
+ left_offset = scale_left - left_idx
+ top_offset = scale_top - top_idx
+ right_offset = scale_right - right_idx
+ bottom_offset = scale_bottom - bottom_idx
+ gt_tl_offset[batch_id, 0, top_idx, left_idx] = left_offset
+ gt_tl_offset[batch_id, 1, top_idx, left_idx] = top_offset
+ gt_br_offset[batch_id, 0, bottom_idx, right_idx] = right_offset
+ gt_br_offset[batch_id, 1, bottom_idx,
+ right_idx] = bottom_offset
+
+ # Generate corner embedding
+ if with_corner_emb:
+ corner_match.append([[top_idx, left_idx],
+ [bottom_idx, right_idx]])
+ # Generate guiding shift
+ if with_guiding_shift:
+ gt_tl_guiding_shift[batch_id, 0, top_idx,
+ left_idx] = scale_center_x - left_idx
+ gt_tl_guiding_shift[batch_id, 1, top_idx,
+ left_idx] = scale_center_y - top_idx
+ gt_br_guiding_shift[batch_id, 0, bottom_idx,
+ right_idx] = right_idx - scale_center_x
+ gt_br_guiding_shift[
+ batch_id, 1, bottom_idx,
+ right_idx] = bottom_idx - scale_center_y
+ # Generate centripetal shift
+ if with_centripetal_shift:
+ gt_tl_centripetal_shift[batch_id, 0, top_idx,
+ left_idx] = log(scale_center_x -
+ scale_left)
+ gt_tl_centripetal_shift[batch_id, 1, top_idx,
+ left_idx] = log(scale_center_y -
+ scale_top)
+ gt_br_centripetal_shift[batch_id, 0, bottom_idx,
+ right_idx] = log(scale_right -
+ scale_center_x)
+ gt_br_centripetal_shift[batch_id, 1, bottom_idx,
+ right_idx] = log(scale_bottom -
+ scale_center_y)
+
+ if with_corner_emb:
+ match.append(corner_match)
+
+ target_result = dict(
+ topleft_heatmap=gt_tl_heatmap,
+ topleft_offset=gt_tl_offset,
+ bottomright_heatmap=gt_br_heatmap,
+ bottomright_offset=gt_br_offset)
+
+ if with_corner_emb:
+ target_result.update(corner_embedding=match)
+ if with_guiding_shift:
+ target_result.update(
+ topleft_guiding_shift=gt_tl_guiding_shift,
+ bottomright_guiding_shift=gt_br_guiding_shift)
+ if with_centripetal_shift:
+ target_result.update(
+ topleft_centripetal_shift=gt_tl_centripetal_shift,
+ bottomright_centripetal_shift=gt_br_centripetal_shift)
+
+ return target_result
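+
+        # Illustrative arithmetic (added note, not part of the original code):
+        # with a stride-4 output (width_ratio = height_ratio = 0.25), a GT box
+        # with left = 17.3 maps to scale_left = 4.325, so the top-left corner
+        # is placed at left_idx = 4 and the corner offset target stores the
+        # lost fraction 0.325; the guiding / centripetal shift targets are
+        # written at the same (top_idx, left_idx) location.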
+
+ def loss(self,
+ tl_heats,
+ br_heats,
+ tl_embs,
+ br_embs,
+ tl_offs,
+ br_offs,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_embs (list[Tensor]): Top-left corner embeddings for each level
+ with shape (N, corner_emb_channels, H, W).
+ br_embs (list[Tensor]): Bottom-right corner embeddings for each
+ level with shape (N, corner_emb_channels, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [left, top, right, bottom] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components. Containing the
+ following losses:
+
+ - det_loss (list[Tensor]): Corner keypoint losses of all
+ feature levels.
+ - pull_loss (list[Tensor]): Part one of AssociativeEmbedding
+ losses of all feature levels.
+ - push_loss (list[Tensor]): Part two of AssociativeEmbedding
+ losses of all feature levels.
+ - off_loss (list[Tensor]): Corner offset losses of all feature
+ levels.
+ """
+ targets = self.get_targets(
+ gt_bboxes,
+ gt_labels,
+ tl_heats[-1].shape,
+ img_metas[0]['pad_shape'],
+ with_corner_emb=self.with_corner_emb)
+ mlvl_targets = [targets for _ in range(self.num_feat_levels)]
+ det_losses, pull_losses, push_losses, off_losses = multi_apply(
+ self.loss_single, tl_heats, br_heats, tl_embs, br_embs, tl_offs,
+ br_offs, mlvl_targets)
+ loss_dict = dict(det_loss=det_losses, off_loss=off_losses)
+ if self.with_corner_emb:
+ loss_dict.update(pull_loss=pull_losses, push_loss=push_losses)
+ return loss_dict
+
+ def loss_single(self, tl_hmp, br_hmp, tl_emb, br_emb, tl_off, br_off,
+ targets):
+ """Compute losses for single level.
+
+ Args:
+ tl_hmp (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_hmp (Tensor): Bottom-right corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ tl_emb (Tensor): Top-left corner embedding for current level with
+ shape (N, corner_emb_channels, H, W).
+ br_emb (Tensor): Bottom-right corner embedding for current level
+ with shape (N, corner_emb_channels, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ targets (dict): Corner target generated by `get_targets`.
+
+ Returns:
+            tuple[torch.Tensor]: Losses of the head's different branches
+ containing the following losses:
+
+ - det_loss (Tensor): Corner keypoint loss.
+ - pull_loss (Tensor): Part one of AssociativeEmbedding loss.
+ - push_loss (Tensor): Part two of AssociativeEmbedding loss.
+ - off_loss (Tensor): Corner offset loss.
+ """
+ gt_tl_hmp = targets['topleft_heatmap']
+ gt_br_hmp = targets['bottomright_heatmap']
+ gt_tl_off = targets['topleft_offset']
+ gt_br_off = targets['bottomright_offset']
+ gt_embedding = targets['corner_embedding']
+
+ # Detection loss
+ tl_det_loss = self.loss_heatmap(
+ tl_hmp.sigmoid(),
+ gt_tl_hmp,
+ avg_factor=max(1,
+ gt_tl_hmp.eq(1).sum()))
+ br_det_loss = self.loss_heatmap(
+ br_hmp.sigmoid(),
+ gt_br_hmp,
+ avg_factor=max(1,
+ gt_br_hmp.eq(1).sum()))
+ det_loss = (tl_det_loss + br_det_loss) / 2.0
+
+ # AssociativeEmbedding loss
+ if self.with_corner_emb and self.loss_embedding is not None:
+ pull_loss, push_loss = self.loss_embedding(tl_emb, br_emb,
+ gt_embedding)
+ else:
+ pull_loss, push_loss = None, None
+
+ # Offset loss
+ # We only compute the offset loss at the real corner position.
+ # The value of real corner would be 1 in heatmap ground truth.
+ # The mask is computed in class agnostic mode and its shape is
+ # batch * 1 * width * height.
+ tl_off_mask = gt_tl_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_tl_hmp)
+ br_off_mask = gt_br_hmp.eq(1).sum(1).gt(0).unsqueeze(1).type_as(
+ gt_br_hmp)
+ tl_off_loss = self.loss_offset(
+ tl_off,
+ gt_tl_off,
+ tl_off_mask,
+ avg_factor=max(1, tl_off_mask.sum()))
+ br_off_loss = self.loss_offset(
+ br_off,
+ gt_br_off,
+ br_off_mask,
+ avg_factor=max(1, br_off_mask.sum()))
+
+ off_loss = (tl_off_loss + br_off_loss) / 2.0
+
+ return det_loss, pull_loss, push_loss, off_loss
+
+ def get_bboxes(self,
+ tl_heats,
+ br_heats,
+ tl_embs,
+ br_embs,
+ tl_offs,
+ br_offs,
+ img_metas,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ tl_heats (list[Tensor]): Top-left corner heatmaps for each level
+ with shape (N, num_classes, H, W).
+ br_heats (list[Tensor]): Bottom-right corner heatmaps for each
+ level with shape (N, num_classes, H, W).
+ tl_embs (list[Tensor]): Top-left corner embeddings for each level
+ with shape (N, corner_emb_channels, H, W).
+ br_embs (list[Tensor]): Bottom-right corner embeddings for each
+ level with shape (N, corner_emb_channels, H, W).
+ tl_offs (list[Tensor]): Top-left corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ br_offs (list[Tensor]): Bottom-right corner offsets for each level
+ with shape (N, corner_offset_channels, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+            with_nms (bool): If True, do NMS before returning boxes.
+                Default: True.
+ """
+ assert tl_heats[-1].shape[0] == br_heats[-1].shape[0] == len(img_metas)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ result_list.append(
+ self._get_bboxes_single(
+ tl_heats[-1][img_id:img_id + 1, :],
+ br_heats[-1][img_id:img_id + 1, :],
+ tl_offs[-1][img_id:img_id + 1, :],
+ br_offs[-1][img_id:img_id + 1, :],
+ img_metas[img_id],
+ tl_emb=tl_embs[-1][img_id:img_id + 1, :],
+ br_emb=br_embs[-1][img_id:img_id + 1, :],
+ rescale=rescale,
+ with_nms=with_nms))
+
+ return result_list
+
+ def _get_bboxes_single(self,
+ tl_heat,
+ br_heat,
+ tl_off,
+ br_off,
+ img_meta,
+ tl_emb=None,
+ br_emb=None,
+ tl_centripetal_shift=None,
+ br_centripetal_shift=None,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+ tl_heat (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_heat (Tensor): Bottom-right corner heatmap for current level
+ with shape (N, num_classes, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ img_meta (dict): Meta information of current image, e.g.,
+ image size, scaling factor, etc.
+ tl_emb (Tensor): Top-left corner embedding for current level with
+ shape (N, corner_emb_channels, H, W).
+ br_emb (Tensor): Bottom-right corner embedding for current level
+ with shape (N, corner_emb_channels, H, W).
+ tl_centripetal_shift: Top-left corner's centripetal shift for
+ current level with shape (N, 2, H, W).
+ br_centripetal_shift: Bottom-right corner's centripetal shift for
+ current level with shape (N, 2, H, W).
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+            with_nms (bool): If True, do NMS before returning boxes.
+                Default: True.
+ """
+ if isinstance(img_meta, (list, tuple)):
+ img_meta = img_meta[0]
+
+ batch_bboxes, batch_scores, batch_clses = self.decode_heatmap(
+ tl_heat=tl_heat.sigmoid(),
+ br_heat=br_heat.sigmoid(),
+ tl_off=tl_off,
+ br_off=br_off,
+ tl_emb=tl_emb,
+ br_emb=br_emb,
+ tl_centripetal_shift=tl_centripetal_shift,
+ br_centripetal_shift=br_centripetal_shift,
+ img_meta=img_meta,
+ k=self.test_cfg.corner_topk,
+ kernel=self.test_cfg.local_maximum_kernel,
+ distance_threshold=self.test_cfg.distance_threshold)
+
+ if rescale:
+ batch_bboxes /= batch_bboxes.new_tensor(img_meta['scale_factor'])
+
+ bboxes = batch_bboxes.view([-1, 4])
+ scores = batch_scores.view([-1, 1])
+ clses = batch_clses.view([-1, 1])
+
+ idx = scores.argsort(dim=0, descending=True)
+ bboxes = bboxes[idx].view([-1, 4])
+ scores = scores[idx].view(-1)
+ clses = clses[idx].view(-1)
+
+ detections = torch.cat([bboxes, scores.unsqueeze(-1)], -1)
+ keepinds = (detections[:, -1] > -0.1)
+ detections = detections[keepinds]
+ labels = clses[keepinds]
+
+ if with_nms:
+ detections, labels = self._bboxes_nms(detections, labels,
+ self.test_cfg)
+
+ return detections, labels
+
+ def _bboxes_nms(self, bboxes, labels, cfg):
+ if labels.numel() == 0:
+ return bboxes, labels
+
+ if 'nms_cfg' in cfg:
+            # `warnings` (standard library) is assumed to be imported at the
+            # top of this module
+            warnings.warn('nms_cfg in test_cfg will be deprecated. '
+                          'Please rename it as nms')
+ if 'nms' not in cfg:
+ cfg.nms = cfg.nms_cfg
+
+ out_bboxes, keep = batched_nms(bboxes[:, :4], bboxes[:, -1], labels,
+ cfg.nms)
+ out_labels = labels[keep]
+
+ if len(out_bboxes) > 0:
+ idx = torch.argsort(out_bboxes[:, -1], descending=True)
+ idx = idx[:cfg.max_per_img]
+ out_bboxes = out_bboxes[idx]
+ out_labels = out_labels[idx]
+
+ return out_bboxes, out_labels
+
+ def _gather_feat(self, feat, ind, mask=None):
+ """Gather feature according to index.
+
+ Args:
+ feat (Tensor): Target feature map.
+ ind (Tensor): Target coord index.
+ mask (Tensor | None): Mask of featuremap. Default: None.
+
+ Returns:
+ feat (Tensor): Gathered feature.
+ """
+ dim = feat.size(2)
+ ind = ind.unsqueeze(2).repeat(1, 1, dim)
+ feat = feat.gather(1, ind)
+ if mask is not None:
+ mask = mask.unsqueeze(2).expand_as(feat)
+ feat = feat[mask]
+ feat = feat.view(-1, dim)
+ return feat
+
+ def _local_maximum(self, heat, kernel=3):
+ """Extract local maximum pixel with given kernel.
+
+ Args:
+ heat (Tensor): Target heatmap.
+ kernel (int): Kernel size of max pooling. Default: 3.
+
+ Returns:
+            heat (Tensor): A heatmap where local maximum pixels keep their
+                own values and all other positions are set to 0.
+ """
+ pad = (kernel - 1) // 2
+ hmax = F.max_pool2d(heat, kernel, stride=1, padding=pad)
+ keep = (hmax == heat).float()
+ return heat * keep
+
+ def _transpose_and_gather_feat(self, feat, ind):
+ """Transpose and gather feature according to index.
+
+ Args:
+ feat (Tensor): Target feature map.
+ ind (Tensor): Target coord index.
+
+ Returns:
+ feat (Tensor): Transposed and gathered feature.
+ """
+ feat = feat.permute(0, 2, 3, 1).contiguous()
+ feat = feat.view(feat.size(0), -1, feat.size(3))
+ feat = self._gather_feat(feat, ind)
+ return feat
+
+ def _topk(self, scores, k=20):
+ """Get top k positions from heatmap.
+
+ Args:
+ scores (Tensor): Target heatmap with shape
+ [batch, num_classes, height, width].
+ k (int): Target number. Default: 20.
+
+ Returns:
+ tuple[torch.Tensor]: Scores, indexes, categories and coords of
+ topk keypoint. Containing following Tensors:
+
+ - topk_scores (Tensor): Max scores of each topk keypoint.
+ - topk_inds (Tensor): Indexes of each topk keypoint.
+ - topk_clses (Tensor): Categories of each topk keypoint.
+ - topk_ys (Tensor): Y-coord of each topk keypoint.
+ - topk_xs (Tensor): X-coord of each topk keypoint.
+ """
+ batch, _, height, width = scores.size()
+ topk_scores, topk_inds = torch.topk(scores.view(batch, -1), k)
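+        # scores was flattened to (batch, num_classes * H * W), so a flat
+        # index decomposes as ind = cls * (H * W) + y * W + x; the integer
+        # division / modulo steps below recover cls, y and x in that order.
+        # E.g. with num_classes=2 and H = W = 4, a flat index of 21 gives
+        # cls = 21 // 16 = 1, remainder 5, y = 5 // 4 = 1, x = 5 % 4 = 1.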
+ topk_clses = topk_inds // (height * width)
+ topk_inds = topk_inds % (height * width)
+        topk_ys = (topk_inds // width).int().float()
+ topk_xs = (topk_inds % width).int().float()
+ return topk_scores, topk_inds, topk_clses, topk_ys, topk_xs
+
+ def decode_heatmap(self,
+ tl_heat,
+ br_heat,
+ tl_off,
+ br_off,
+ tl_emb=None,
+ br_emb=None,
+ tl_centripetal_shift=None,
+ br_centripetal_shift=None,
+ img_meta=None,
+ k=100,
+ kernel=3,
+ distance_threshold=0.5,
+ num_dets=1000):
+ """Transform outputs for a single batch item into raw bbox predictions.
+
+ Args:
+ tl_heat (Tensor): Top-left corner heatmap for current level with
+ shape (N, num_classes, H, W).
+ br_heat (Tensor): Bottom-right corner heatmap for current level
+ with shape (N, num_classes, H, W).
+ tl_off (Tensor): Top-left corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ br_off (Tensor): Bottom-right corner offset for current level with
+ shape (N, corner_offset_channels, H, W).
+ tl_emb (Tensor | None): Top-left corner embedding for current
+ level with shape (N, corner_emb_channels, H, W).
+ br_emb (Tensor | None): Bottom-right corner embedding for current
+ level with shape (N, corner_emb_channels, H, W).
+ tl_centripetal_shift (Tensor | None): Top-left centripetal shift
+ for current level with shape (N, 2, H, W).
+ br_centripetal_shift (Tensor | None): Bottom-right centripetal
+ shift for current level with shape (N, 2, H, W).
+ img_meta (dict): Meta information of current image, e.g.,
+ image size, scaling factor, etc.
+ k (int): Get top k corner keypoints from heatmap.
+ kernel (int): Max pooling kernel for extract local maximum pixels.
+            distance_threshold (float): Distance threshold. Top-left and
+                bottom-right corner keypoints with feature distance less than
+                the threshold are regarded as keypoints from the same object.
+            num_dets (int): Number of raw boxes kept before NMS.
+
+ Returns:
+ tuple[torch.Tensor]: Decoded output of CornerHead, containing the
+ following Tensors:
+
+ - bboxes (Tensor): Coords of each box.
+ - scores (Tensor): Scores of each box.
+ - clses (Tensor): Categories of each box.
+ """
+ with_embedding = tl_emb is not None and br_emb is not None
+ with_centripetal_shift = (
+ tl_centripetal_shift is not None
+ and br_centripetal_shift is not None)
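+        # exactly one corner-grouping mode must be active: associative
+        # embedding (CornerNet-style) or centripetal shift (CentripetalNet)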
+ assert with_embedding + with_centripetal_shift == 1
+ batch, _, height, width = tl_heat.size()
+ inp_h, inp_w, _ = img_meta['pad_shape']
+
+ # perform nms on heatmaps
+ tl_heat = self._local_maximum(tl_heat, kernel=kernel)
+ br_heat = self._local_maximum(br_heat, kernel=kernel)
+
+ tl_scores, tl_inds, tl_clses, tl_ys, tl_xs = self._topk(tl_heat, k=k)
+ br_scores, br_inds, br_clses, br_ys, br_xs = self._topk(br_heat, k=k)
+
+        # We use repeat instead of expand here because expand is a
+        # shallow-copy function and can occasionally cause unexpected test
+        # results. Using expand instead of repeat decreases mAP by about 10%
+        # during testing.
+ tl_ys = tl_ys.view(batch, k, 1).repeat(1, 1, k)
+ tl_xs = tl_xs.view(batch, k, 1).repeat(1, 1, k)
+ br_ys = br_ys.view(batch, 1, k).repeat(1, k, 1)
+ br_xs = br_xs.view(batch, 1, k).repeat(1, k, 1)
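+        # tl_* varies along dim 1 and br_* along dim 2, so entry (b, i, j)
+        # pairs the i-th top-left corner with the j-th bottom-right corner,
+        # giving k x k candidate boxes per image.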
+
+ tl_off = self._transpose_and_gather_feat(tl_off, tl_inds)
+ tl_off = tl_off.view(batch, k, 1, 2)
+ br_off = self._transpose_and_gather_feat(br_off, br_inds)
+ br_off = br_off.view(batch, 1, k, 2)
+
+ tl_xs = tl_xs + tl_off[..., 0]
+ tl_ys = tl_ys + tl_off[..., 1]
+ br_xs = br_xs + br_off[..., 0]
+ br_ys = br_ys + br_off[..., 1]
+
+ if with_centripetal_shift:
+ tl_centripetal_shift = self._transpose_and_gather_feat(
+ tl_centripetal_shift, tl_inds).view(batch, k, 1, 2).exp()
+ br_centripetal_shift = self._transpose_and_gather_feat(
+ br_centripetal_shift, br_inds).view(batch, 1, k, 2).exp()
+
+ tl_ctxs = tl_xs + tl_centripetal_shift[..., 0]
+ tl_ctys = tl_ys + tl_centripetal_shift[..., 1]
+ br_ctxs = br_xs - br_centripetal_shift[..., 0]
+ br_ctys = br_ys - br_centripetal_shift[..., 1]
+
+ # all possible boxes based on top k corners (ignoring class)
+ tl_xs *= (inp_w / width)
+ tl_ys *= (inp_h / height)
+ br_xs *= (inp_w / width)
+ br_ys *= (inp_h / height)
+
+ if with_centripetal_shift:
+ tl_ctxs *= (inp_w / width)
+ tl_ctys *= (inp_h / height)
+ br_ctxs *= (inp_w / width)
+ br_ctys *= (inp_h / height)
+
+ x_off = img_meta['border'][2]
+ y_off = img_meta['border'][0]
+
+ tl_xs -= x_off
+ tl_ys -= y_off
+ br_xs -= x_off
+ br_ys -= y_off
+
+ tl_xs *= tl_xs.gt(0.0).type_as(tl_xs)
+ tl_ys *= tl_ys.gt(0.0).type_as(tl_ys)
+ br_xs *= br_xs.gt(0.0).type_as(br_xs)
+ br_ys *= br_ys.gt(0.0).type_as(br_ys)
+
+ bboxes = torch.stack((tl_xs, tl_ys, br_xs, br_ys), dim=3)
+ area_bboxes = ((br_xs - tl_xs) * (br_ys - tl_ys)).abs()
+
+ if with_centripetal_shift:
+ tl_ctxs -= x_off
+ tl_ctys -= y_off
+ br_ctxs -= x_off
+ br_ctys -= y_off
+
+ tl_ctxs *= tl_ctxs.gt(0.0).type_as(tl_ctxs)
+ tl_ctys *= tl_ctys.gt(0.0).type_as(tl_ctys)
+ br_ctxs *= br_ctxs.gt(0.0).type_as(br_ctxs)
+ br_ctys *= br_ctys.gt(0.0).type_as(br_ctys)
+
+ ct_bboxes = torch.stack((tl_ctxs, tl_ctys, br_ctxs, br_ctys),
+ dim=3)
+ area_ct_bboxes = ((br_ctxs - tl_ctxs) * (br_ctys - tl_ctys)).abs()
+
+ rcentral = torch.zeros_like(ct_bboxes)
+            # magic numbers from section 4.1 of the paper
+            mu = torch.ones_like(area_bboxes) / 2.4
+            mu[area_bboxes > 3500] = 1 / 2.1  # larger bboxes get a smaller mu
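+            # rcentral is the central region of each candidate box, scaled by
+            # mu around its center; a corner pair survives only if both
+            # shifted corner centers fall inside it (checked via the
+            # *_ct*_inds masks below).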
+
+ bboxes_center_x = (bboxes[..., 0] + bboxes[..., 2]) / 2
+ bboxes_center_y = (bboxes[..., 1] + bboxes[..., 3]) / 2
+ rcentral[..., 0] = bboxes_center_x - mu * (bboxes[..., 2] -
+ bboxes[..., 0]) / 2
+ rcentral[..., 1] = bboxes_center_y - mu * (bboxes[..., 3] -
+ bboxes[..., 1]) / 2
+ rcentral[..., 2] = bboxes_center_x + mu * (bboxes[..., 2] -
+ bboxes[..., 0]) / 2
+ rcentral[..., 3] = bboxes_center_y + mu * (bboxes[..., 3] -
+ bboxes[..., 1]) / 2
+ area_rcentral = ((rcentral[..., 2] - rcentral[..., 0]) *
+ (rcentral[..., 3] - rcentral[..., 1])).abs()
+ dists = area_ct_bboxes / area_rcentral
+
+ tl_ctx_inds = (ct_bboxes[..., 0] <= rcentral[..., 0]) | (
+ ct_bboxes[..., 0] >= rcentral[..., 2])
+ tl_cty_inds = (ct_bboxes[..., 1] <= rcentral[..., 1]) | (
+ ct_bboxes[..., 1] >= rcentral[..., 3])
+ br_ctx_inds = (ct_bboxes[..., 2] <= rcentral[..., 0]) | (
+ ct_bboxes[..., 2] >= rcentral[..., 2])
+ br_cty_inds = (ct_bboxes[..., 3] <= rcentral[..., 1]) | (
+ ct_bboxes[..., 3] >= rcentral[..., 3])
+
+ if with_embedding:
+ tl_emb = self._transpose_and_gather_feat(tl_emb, tl_inds)
+ tl_emb = tl_emb.view(batch, k, 1)
+ br_emb = self._transpose_and_gather_feat(br_emb, br_inds)
+ br_emb = br_emb.view(batch, 1, k)
+ dists = torch.abs(tl_emb - br_emb)
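+            # pairs whose embedding distance exceeds distance_threshold are
+            # rejected later via dist_inds.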
+
+ tl_scores = tl_scores.view(batch, k, 1).repeat(1, 1, k)
+ br_scores = br_scores.view(batch, 1, k).repeat(1, k, 1)
+
+ scores = (tl_scores + br_scores) / 2 # scores for all possible boxes
+
+ # tl and br should have same class
+ tl_clses = tl_clses.view(batch, k, 1).repeat(1, 1, k)
+ br_clses = br_clses.view(batch, 1, k).repeat(1, k, 1)
+ cls_inds = (tl_clses != br_clses)
+
+ # reject boxes based on distances
+ dist_inds = dists > distance_threshold
+
+ # reject boxes based on widths and heights
+ width_inds = (br_xs <= tl_xs)
+ height_inds = (br_ys <= tl_ys)
+
+ scores[cls_inds] = -1
+ scores[width_inds] = -1
+ scores[height_inds] = -1
+ scores[dist_inds] = -1
+ if with_centripetal_shift:
+ scores[tl_ctx_inds] = -1
+ scores[tl_cty_inds] = -1
+ scores[br_ctx_inds] = -1
+ scores[br_cty_inds] = -1
+
+ scores = scores.view(batch, -1)
+ scores, inds = torch.topk(scores, num_dets)
+ scores = scores.unsqueeze(2)
+
+ bboxes = bboxes.view(batch, -1, 4)
+ bboxes = self._gather_feat(bboxes, inds)
+
+ clses = tl_clses.contiguous().view(batch, -1, 1)
+ clses = self._gather_feat(clses, inds).float()
+
+ return bboxes, scores, clses
diff --git a/mmdet/models/dense_heads/dense_test_mixins.py b/mmdet/models/dense_heads/dense_test_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd81364dec90e97c30a6e2220a5e0fe96373c5bd
--- /dev/null
+++ b/mmdet/models/dense_heads/dense_test_mixins.py
@@ -0,0 +1,100 @@
+from inspect import signature
+
+import torch
+
+from mmdet.core import bbox2result, bbox_mapping_back, multiclass_nms
+
+
+class BBoxTestMixin(object):
+ """Mixin class for test time augmentation of bboxes."""
+
+ def merge_aug_bboxes(self, aug_bboxes, aug_scores, img_metas):
+ """Merge augmented detection bboxes and scores.
+
+ Args:
+ aug_bboxes (list[Tensor]): shape (n, 4*#class)
+ aug_scores (list[Tensor] or None): shape (n, #class)
+            img_metas (list[list[dict]]): Meta information of each augmented
+                image, e.g., image size, scaling factor, flip flags, etc.
+
+ Returns:
+ tuple: (bboxes, scores)
+ """
+ recovered_bboxes = []
+ for bboxes, img_info in zip(aug_bboxes, img_metas):
+ img_shape = img_info[0]['img_shape']
+ scale_factor = img_info[0]['scale_factor']
+ flip = img_info[0]['flip']
+ flip_direction = img_info[0]['flip_direction']
+ bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip,
+ flip_direction)
+ recovered_bboxes.append(bboxes)
+ bboxes = torch.cat(recovered_bboxes, dim=0)
+ if aug_scores is None:
+ return bboxes
+ else:
+ scores = torch.cat(aug_scores, dim=0)
+ return bboxes, scores
+
+ def aug_test_bboxes(self, feats, img_metas, rescale=False):
+ """Test det bboxes with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+                images in a batch. Each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ # check with_nms argument
+ gb_sig = signature(self.get_bboxes)
+ gb_args = [p.name for p in gb_sig.parameters.values()]
+ if hasattr(self, '_get_bboxes'):
+ gbs_sig = signature(self._get_bboxes)
+ else:
+ gbs_sig = signature(self._get_bboxes_single)
+ gbs_args = [p.name for p in gbs_sig.parameters.values()]
+ assert ('with_nms' in gb_args) and ('with_nms' in gbs_args), \
+ f'{self.__class__.__name__}' \
+ ' does not support test-time augmentation'
+
+ aug_bboxes = []
+ aug_scores = []
+ aug_factors = [] # score_factors for NMS
+ for x, img_meta in zip(feats, img_metas):
+ # only one image in the batch
+ outs = self.forward(x)
+ bbox_inputs = outs + (img_meta, self.test_cfg, False, False)
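+            # the trailing (False, False) are rescale and with_nms: keep the
+            # per-augmentation boxes in their own (augmented) image space and
+            # skip NMS here so they can be merged and suppressed once below.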
+ bbox_outputs = self.get_bboxes(*bbox_inputs)[0]
+ aug_bboxes.append(bbox_outputs[0])
+ aug_scores.append(bbox_outputs[1])
+ # bbox_outputs of some detectors (e.g., ATSS, FCOS, YOLOv3)
+ # contains additional element to adjust scores before NMS
+ if len(bbox_outputs) >= 3:
+ aug_factors.append(bbox_outputs[2])
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = self.merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas)
+ merged_factors = torch.cat(aug_factors, dim=0) if aug_factors else None
+ det_bboxes, det_labels = multiclass_nms(
+ merged_bboxes,
+ merged_scores,
+ self.test_cfg.score_thr,
+ self.test_cfg.nms,
+ self.test_cfg.max_per_img,
+ score_factors=merged_factors)
+
+ if rescale:
+ _det_bboxes = det_bboxes
+ else:
+ _det_bboxes = det_bboxes.clone()
+ _det_bboxes[:, :4] *= det_bboxes.new_tensor(
+ img_metas[0][0]['scale_factor'])
+ bbox_results = bbox2result(_det_bboxes, det_labels, self.num_classes)
+ return bbox_results
diff --git a/mmdet/models/dense_heads/embedding_rpn_head.py b/mmdet/models/dense_heads/embedding_rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..200ce8d20c5503f98c5c21f30bb9d00437e25f34
--- /dev/null
+++ b/mmdet/models/dense_heads/embedding_rpn_head.py
@@ -0,0 +1,100 @@
+import torch
+import torch.nn as nn
+
+from mmdet.models.builder import HEADS
+from ...core import bbox_cxcywh_to_xyxy
+
+
+@HEADS.register_module()
+class EmbeddingRPNHead(nn.Module):
+    """RPNHead in the `Sparse R-CNN <https://arxiv.org/abs/2011.12450>`_ .
+
+ Unlike traditional RPNHead, this module does not need FPN input, but just
+ decode `init_proposal_bboxes` and expand the first dimension of
+ `init_proposal_bboxes` and `init_proposal_features` to the batch_size.
+
+ Args:
+ num_proposals (int): Number of init_proposals. Default 100.
+ proposal_feature_channel (int): Channel number of
+ init_proposal_feature. Defaults to 256.
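+
+    Example (an illustrative sketch, not part of the original code; it
+    assumes two images in the batch and a single feature level):
+        >>> self = EmbeddingRPNHead(num_proposals=10,
+        ...                         proposal_feature_channel=8)
+        >>> self.init_weights()
+        >>> feats = [torch.rand(2, 8, 32, 32)]
+        >>> img_metas = [dict(img_shape=(256, 256, 3))] * 2
+        >>> proposals, prop_feats, imgs_whwh = self.forward_dummy(
+        ...     feats, img_metas)
+        >>> assert proposals.shape == (2, 10, 4)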
+ """
+
+ def __init__(self,
+ num_proposals=100,
+ proposal_feature_channel=256,
+ **kwargs):
+ super(EmbeddingRPNHead, self).__init__()
+ self.num_proposals = num_proposals
+ self.proposal_feature_channel = proposal_feature_channel
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize a sparse set of proposal boxes and proposal features."""
+ self.init_proposal_bboxes = nn.Embedding(self.num_proposals, 4)
+ self.init_proposal_features = nn.Embedding(
+ self.num_proposals, self.proposal_feature_channel)
+
+ def init_weights(self):
+        """Initialize init_proposal_bboxes as normalized [c_x, c_y, w, h]
+        boxes covering the entire image.
+        """
+ nn.init.constant_(self.init_proposal_bboxes.weight[:, :2], 0.5)
+ nn.init.constant_(self.init_proposal_bboxes.weight[:, 2:], 1)
+
+ def _decode_init_proposals(self, imgs, img_metas):
+ """Decode init_proposal_bboxes according to the size of images and
+ expand dimension of init_proposal_features to batch_size.
+
+ Args:
+ imgs (list[Tensor]): List of FPN features.
+ img_metas (list[dict]): List of meta-information of
+ images. Need the img_shape to decode the init_proposals.
+
+ Returns:
+ Tuple(Tensor):
+
+ - proposals (Tensor): Decoded proposal bboxes,
+ has shape (batch_size, num_proposals, 4).
+ - init_proposal_features (Tensor): Expanded proposal
+ features, has shape
+ (batch_size, num_proposals, proposal_feature_channel).
+                - imgs_whwh (Tensor): Tensor with shape
+                    (batch_size, 1, 4); the last dimension means
+                    [img_width, img_height, img_width, img_height].
+ """
+ proposals = self.init_proposal_bboxes.weight.clone()
+ proposals = bbox_cxcywh_to_xyxy(proposals)
+ num_imgs = len(imgs[0])
+ imgs_whwh = []
+ for meta in img_metas:
+ h, w, _ = meta['img_shape']
+ imgs_whwh.append(imgs[0].new_tensor([[w, h, w, h]]))
+ imgs_whwh = torch.cat(imgs_whwh, dim=0)
+ imgs_whwh = imgs_whwh[:, None, :]
+
+        # imgs_whwh has shape (batch_size, 1, 4); broadcasting against it
+        # expands proposals from (num_proposals, 4) to
+        # (batch_size, num_proposals, 4)
+ proposals = proposals * imgs_whwh
+
+ init_proposal_features = self.init_proposal_features.weight.clone()
+ init_proposal_features = init_proposal_features[None].expand(
+ num_imgs, *init_proposal_features.size())
+ return proposals, init_proposal_features, imgs_whwh
+
+ def forward_dummy(self, img, img_metas):
+ """Dummy forward function.
+
+ Used in flops calculation.
+ """
+ return self._decode_init_proposals(img, img_metas)
+
+ def forward_train(self, img, img_metas):
+ """Forward function in training stage."""
+ return self._decode_init_proposals(img, img_metas)
+
+ def simple_test_rpn(self, img, img_metas):
+ """Forward function in testing stage."""
+ return self._decode_init_proposals(img, img_metas)
diff --git a/mmdet/models/dense_heads/fcos_head.py b/mmdet/models/dense_heads/fcos_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..905a703507f279ac8d34cff23c99af33c0d5f973
--- /dev/null
+++ b/mmdet/models/dense_heads/fcos_head.py
@@ -0,0 +1,629 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Scale, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import distance2bbox, multi_apply, multiclass_nms, reduce_mean
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+INF = 1e8
+
+
+@HEADS.register_module()
+class FCOSHead(AnchorFreeHead):
+    """Anchor-free head used in `FCOS <https://arxiv.org/abs/1904.01355>`_.
+
+ The FCOS head does not use anchor boxes. Instead bounding boxes are
+ predicted at each pixel and a centerness measure is used to suppress
+ low-quality predictions.
+    Here norm_on_bbox, centerness_on_reg and dcn_on_last_conv are training
+    tricks used in the official repo, which bring mAP gains of up to
+    4.9 points. Please see https://github.com/tianzhi0549/FCOS for
+    more details.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ strides (list[int] | list[tuple[int, int]]): Strides of points
+ in multiple feature levels. Default: (4, 8, 16, 32, 64).
+ regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
+ level points.
+ center_sampling (bool): If true, use center sampling. Default: False.
+ center_sample_radius (float): Radius of center sampling. Default: 1.5.
+ norm_on_bbox (bool): If true, normalize the regression targets
+ with FPN strides. Default: False.
+ centerness_on_reg (bool): If true, position centerness on the
+ regress branch. Please refer to https://github.com/tianzhi0549/FCOS/issues/89#issuecomment-516877042.
+ Default: False.
+ conv_bias (bool | str): If specified as `auto`, it will be decided by the
+ norm_cfg. Bias of conv will be set as True if `norm_cfg` is None, otherwise
+ False. Default: "auto".
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ loss_centerness (dict): Config of centerness loss.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32, requires_grad=True).
+
+ Example:
+ >>> self = FCOSHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_score, bbox_pred, centerness = self.forward(feats)
+ >>> assert len(cls_score) == len(self.scales)
+ """ # noqa: E501
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
+ (512, INF)),
+ center_sampling=False,
+ center_sample_radius=1.5,
+ norm_on_bbox=False,
+ centerness_on_reg=False,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox=dict(type='IoULoss', loss_weight=1.0),
+ loss_centerness=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ **kwargs):
+ self.regress_ranges = regress_ranges
+ self.center_sampling = center_sampling
+ self.center_sample_radius = center_sample_radius
+ self.norm_on_bbox = norm_on_bbox
+ self.centerness_on_reg = centerness_on_reg
+ super().__init__(
+ num_classes,
+ in_channels,
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ norm_cfg=norm_cfg,
+ **kwargs)
+ self.loss_centerness = build_loss(loss_centerness)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ super()._init_layers()
+ self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
+ self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ super().init_weights()
+ normal_init(self.conv_centerness, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple:
+ cls_scores (list[Tensor]): Box scores for each scale level, \
+ each is a 4D-tensor, the channel number is \
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each \
+ scale level, each is a 4D-tensor, the channel number is \
+ num_points * 4.
+ centernesses (list[Tensor]): centerness for each scale level, \
+ each is a 4D-tensor, the channel number is num_points * 1.
+ """
+ return multi_apply(self.forward_single, feats, self.scales,
+ self.strides)
+
+ def forward_single(self, x, scale, stride):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+ stride (int): The corresponding stride for feature maps, only
+ used to normalize the bbox prediction when self.norm_on_bbox
+ is True.
+
+ Returns:
+ tuple: scores for each class, bbox predictions and centerness \
+ predictions of input feature maps.
+ """
+ cls_score, bbox_pred, cls_feat, reg_feat = super().forward_single(x)
+ if self.centerness_on_reg:
+ centerness = self.conv_centerness(reg_feat)
+ else:
+ centerness = self.conv_centerness(cls_feat)
+ # scale the bbox_pred of different level
+ # float to avoid overflow when enabling FP16
+ bbox_pred = scale(bbox_pred).float()
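+        # the predicted (l, t, r, b) distances must be non-negative: either
+        # clamp with ReLU (norm_on_bbox; rescaled by the stride at test time)
+        # or exponentiate the raw prediction.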
+ if self.norm_on_bbox:
+ bbox_pred = F.relu(bbox_pred)
+ if not self.training:
+ bbox_pred *= stride
+ else:
+ bbox_pred = bbox_pred.exp()
+ return cls_score, bbox_pred, centerness
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level,
+ each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ centernesses (list[Tensor]): centerness for each scale level, each
+ is a 4D-tensor, the channel number is num_points * 1.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(centernesses)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
+ bbox_preds[0].device)
+ labels, bbox_targets = self.get_targets(all_level_points, gt_bboxes,
+ gt_labels)
+
+ num_imgs = cls_scores[0].size(0)
+ # flatten cls_scores, bbox_preds and centerness
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_centerness = [
+ centerness.permute(0, 2, 3, 1).reshape(-1)
+ for centerness in centernesses
+ ]
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_centerness = torch.cat(flatten_centerness)
+ flatten_labels = torch.cat(labels)
+ flatten_bbox_targets = torch.cat(bbox_targets)
+ # repeat points to align with bbox_preds
+ flatten_points = torch.cat(
+ [points.repeat(num_imgs, 1) for points in all_level_points])
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((flatten_labels >= 0)
+ & (flatten_labels < bg_class_ind)).nonzero().reshape(-1)
+ num_pos = torch.tensor(
+ len(pos_inds), dtype=torch.float, device=bbox_preds[0].device)
+ num_pos = max(reduce_mean(num_pos), 1.0)
+ loss_cls = self.loss_cls(
+ flatten_cls_scores, flatten_labels, avg_factor=num_pos)
+
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+ pos_centerness = flatten_centerness[pos_inds]
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_centerness_targets = self.centerness_target(pos_bbox_targets)
+ pos_points = flatten_points[pos_inds]
+ pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
+ pos_decoded_target_preds = distance2bbox(pos_points,
+ pos_bbox_targets)
+ # centerness weighted iou loss
+ centerness_denorm = max(
+ reduce_mean(pos_centerness_targets.sum().detach()), 1e-6)
+ loss_bbox = self.loss_bbox(
+ pos_decoded_bbox_preds,
+ pos_decoded_target_preds,
+ weight=pos_centerness_targets,
+ avg_factor=centerness_denorm)
+ loss_centerness = self.loss_centerness(
+ pos_centerness, pos_centerness_targets, avg_factor=num_pos)
+ else:
+ loss_bbox = pos_bbox_preds.sum()
+ loss_centerness = pos_centerness.sum()
+
+ return dict(
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ loss_centerness=loss_centerness)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'centernesses'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ with shape (N, num_points * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W).
+ centernesses (list[Tensor]): Centerness for each scale level with
+ shape (N, num_points * 1, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default: None.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+            with_nms (bool): If True, do NMS before returning boxes.
+                Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 5) tensor, where 5 represent
+ (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
+ bbox_preds[0].device)
+
+ cls_score_list = [cls_scores[i].detach() for i in range(num_levels)]
+ bbox_pred_list = [bbox_preds[i].detach() for i in range(num_levels)]
+ centerness_pred_list = [
+ centernesses[i].detach() for i in range(num_levels)
+ ]
+ if torch.onnx.is_in_onnx_export():
+ assert len(
+ img_metas
+ ) == 1, 'Only support one input image while in exporting to ONNX'
+ img_shapes = img_metas[0]['img_shape_for_onnx']
+ else:
+ img_shapes = [
+ img_metas[i]['img_shape']
+ for i in range(cls_scores[0].shape[0])
+ ]
+ scale_factors = [
+ img_metas[i]['scale_factor'] for i in range(cls_scores[0].shape[0])
+ ]
+ result_list = self._get_bboxes(cls_score_list, bbox_pred_list,
+ centerness_pred_list, mlvl_points,
+ img_shapes, scale_factors, cfg, rescale,
+ with_nms)
+ return result_list
+
+ def _get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ centernesses,
+ mlvl_points,
+ img_shapes,
+ scale_factors,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+            cls_scores (list[Tensor]): Box scores for each scale level
+                with shape (N, num_points * num_classes, H, W).
+            bbox_preds (list[Tensor]): Box energies / deltas for each scale
+                level with shape (N, num_points * 4, H, W).
+            centernesses (list[Tensor]): Centerness for each scale level
+                with shape (N, num_points * 1, H, W).
+            mlvl_points (list[Tensor]): Point coordinates for each scale
+                level with shape (num_points, 2).
+ img_shapes (list[tuple[int]]): Shape of the input image,
+ list[(height, width, 3)].
+            scale_factors (list[ndarray]): Scale factor of each image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+            with_nms (bool): If True, do NMS before returning boxes.
+                Default: True.
+
+ Returns:
+ tuple(Tensor):
+ det_bboxes (Tensor): BBox predictions in shape (n, 5), where
+ the first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+ between 0 and 1.
+ det_labels (Tensor): A (n,) tensor where each item is the
+ predicted class label of the corresponding box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
+ device = cls_scores[0].device
+ batch_size = cls_scores[0].shape[0]
+ # convert to tensor to keep tracing
+ nms_pre_tensor = torch.tensor(
+ cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_centerness = []
+ for cls_score, bbox_pred, centerness, points in zip(
+ cls_scores, bbox_preds, centernesses, mlvl_points):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ scores = cls_score.permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.cls_out_channels).sigmoid()
+ centerness = centerness.permute(0, 2, 3,
+ 1).reshape(batch_size,
+ -1).sigmoid()
+
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(batch_size, -1, 4)
+ # Always keep topk op for dynamic input in onnx
+ if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
+ or scores.shape[-2] > nms_pre_tensor):
+ from torch import _shape_as_tensor
+ # keep shape as tensor and get k
+ num_anchor = _shape_as_tensor(scores)[-2].to(device)
+ nms_pre = torch.where(nms_pre_tensor < num_anchor,
+ nms_pre_tensor, num_anchor)
+
+ max_scores, _ = (scores * centerness[..., None]).max(-1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ points = points[topk_inds, :]
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+ centerness = centerness[batch_inds, topk_inds]
+
+ bboxes = distance2bbox(points, bbox_pred, max_shape=img_shapes)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_centerness.append(centerness)
+
+ batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+ if rescale:
+ batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+ batch_mlvl_centerness = torch.cat(mlvl_centerness, dim=1)
+
+ # Set max number of box to be feed into nms in deployment
+ deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
+ if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
+ batch_mlvl_scores, _ = (
+ batch_mlvl_scores *
+ batch_mlvl_centerness.unsqueeze(2).expand_as(batch_mlvl_scores)
+ ).max(-1)
+ _, topk_inds = batch_mlvl_scores.topk(deploy_nms_pre)
+ batch_inds = torch.arange(batch_mlvl_scores.shape[0]).view(
+ -1, 1).expand_as(topk_inds)
+ batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
+ batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
+ batch_mlvl_centerness = batch_mlvl_centerness[batch_inds,
+ topk_inds]
+
+        # Note that FG labels are set to [0, num_classes - 1] since mmdet
+        # v2.0, and the BG cat_id is num_classes
+ padding = batch_mlvl_scores.new_zeros(batch_size,
+ batch_mlvl_scores.shape[1], 1)
+ batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
+
+ if with_nms:
+ det_results = []
+ for (mlvl_bboxes, mlvl_scores,
+ mlvl_centerness) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+ batch_mlvl_centerness):
+ det_bbox, det_label = multiclass_nms(
+ mlvl_bboxes,
+ mlvl_scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=mlvl_centerness)
+ det_results.append(tuple([det_bbox, det_label]))
+ else:
+ det_results = [
+ tuple(mlvl_bs)
+ for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+ batch_mlvl_centerness)
+ ]
+ return det_results
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points according to feature map sizes."""
+ y, x = super()._get_points_single(featmap_size, stride, dtype, device)
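+        # map grid coordinates on the feature map to input-image coordinates:
+        # multiply by the stride and shift by stride // 2 so each point lies
+        # at the center of its cell.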
+ points = torch.stack((x.reshape(-1) * stride, y.reshape(-1) * stride),
+ dim=-1) + stride // 2
+ return points
+
+ def get_targets(self, points, gt_bboxes_list, gt_labels_list):
+ """Compute regression, classification and centerness targets for points
+ in multiple images.
+
+ Args:
+ points (list[Tensor]): Points of each fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels_list (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+
+ Returns:
+ tuple:
+ concat_lvl_labels (list[Tensor]): Labels of each level. \
+ concat_lvl_bbox_targets (list[Tensor]): BBox targets of each \
+ level.
+ """
+ assert len(points) == len(self.regress_ranges)
+ num_levels = len(points)
+ # expand regress ranges to align with points
+ expanded_regress_ranges = [
+ points[i].new_tensor(self.regress_ranges[i])[None].expand_as(
+ points[i]) for i in range(num_levels)
+ ]
+ # concat all levels points and regress ranges
+ concat_regress_ranges = torch.cat(expanded_regress_ranges, dim=0)
+ concat_points = torch.cat(points, dim=0)
+
+ # the number of points per img, per lvl
+ num_points = [center.size(0) for center in points]
+
+ # get labels and bbox_targets of each image
+ labels_list, bbox_targets_list = multi_apply(
+ self._get_target_single,
+ gt_bboxes_list,
+ gt_labels_list,
+ points=concat_points,
+ regress_ranges=concat_regress_ranges,
+ num_points_per_lvl=num_points)
+
+ # split to per img, per level
+ labels_list = [labels.split(num_points, 0) for labels in labels_list]
+ bbox_targets_list = [
+ bbox_targets.split(num_points, 0)
+ for bbox_targets in bbox_targets_list
+ ]
+
+ # concat per level image
+ concat_lvl_labels = []
+ concat_lvl_bbox_targets = []
+ for i in range(num_levels):
+ concat_lvl_labels.append(
+ torch.cat([labels[i] for labels in labels_list]))
+ bbox_targets = torch.cat(
+ [bbox_targets[i] for bbox_targets in bbox_targets_list])
+ if self.norm_on_bbox:
+ bbox_targets = bbox_targets / self.strides[i]
+ concat_lvl_bbox_targets.append(bbox_targets)
+ return concat_lvl_labels, concat_lvl_bbox_targets
+
+ def _get_target_single(self, gt_bboxes, gt_labels, points, regress_ranges,
+ num_points_per_lvl):
+ """Compute regression and classification targets for a single image."""
+ num_points = points.size(0)
+ num_gts = gt_labels.size(0)
+ if num_gts == 0:
+ return gt_labels.new_full((num_points,), self.num_classes), \
+ gt_bboxes.new_zeros((num_points, 4))
+
+ areas = (gt_bboxes[:, 2] - gt_bboxes[:, 0]) * (
+ gt_bboxes[:, 3] - gt_bboxes[:, 1])
+ # TODO: figure out why these two are different
+ # areas = areas[None].expand(num_points, num_gts)
+ areas = areas[None].repeat(num_points, 1)
+ regress_ranges = regress_ranges[:, None, :].expand(
+ num_points, num_gts, 2)
+ gt_bboxes = gt_bboxes[None].expand(num_points, num_gts, 4)
+ xs, ys = points[:, 0], points[:, 1]
+ xs = xs[:, None].expand(num_points, num_gts)
+ ys = ys[:, None].expand(num_points, num_gts)
+
+ left = xs - gt_bboxes[..., 0]
+ right = gt_bboxes[..., 2] - xs
+ top = ys - gt_bboxes[..., 1]
+ bottom = gt_bboxes[..., 3] - ys
+ bbox_targets = torch.stack((left, top, right, bottom), -1)
+
+ if self.center_sampling:
+ # condition1: inside a `center bbox`
+ radius = self.center_sample_radius
+ center_xs = (gt_bboxes[..., 0] + gt_bboxes[..., 2]) / 2
+ center_ys = (gt_bboxes[..., 1] + gt_bboxes[..., 3]) / 2
+ center_gts = torch.zeros_like(gt_bboxes)
+ stride = center_xs.new_zeros(center_xs.shape)
+
+ # project the points on current lvl back to the `original` sizes
+ lvl_begin = 0
+ for lvl_idx, num_points_lvl in enumerate(num_points_per_lvl):
+ lvl_end = lvl_begin + num_points_lvl
+ stride[lvl_begin:lvl_end] = self.strides[lvl_idx] * radius
+ lvl_begin = lvl_end
+
+ x_mins = center_xs - stride
+ y_mins = center_ys - stride
+ x_maxs = center_xs + stride
+ y_maxs = center_ys + stride
+ center_gts[..., 0] = torch.where(x_mins > gt_bboxes[..., 0],
+ x_mins, gt_bboxes[..., 0])
+ center_gts[..., 1] = torch.where(y_mins > gt_bboxes[..., 1],
+ y_mins, gt_bboxes[..., 1])
+ center_gts[..., 2] = torch.where(x_maxs > gt_bboxes[..., 2],
+ gt_bboxes[..., 2], x_maxs)
+ center_gts[..., 3] = torch.where(y_maxs > gt_bboxes[..., 3],
+ gt_bboxes[..., 3], y_maxs)
+
+ cb_dist_left = xs - center_gts[..., 0]
+ cb_dist_right = center_gts[..., 2] - xs
+ cb_dist_top = ys - center_gts[..., 1]
+ cb_dist_bottom = center_gts[..., 3] - ys
+ center_bbox = torch.stack(
+ (cb_dist_left, cb_dist_top, cb_dist_right, cb_dist_bottom), -1)
+ inside_gt_bbox_mask = center_bbox.min(-1)[0] > 0
+ else:
+ # condition1: inside a gt bbox
+ inside_gt_bbox_mask = bbox_targets.min(-1)[0] > 0
+
+ # condition2: limit the regression range for each location
+ max_regress_distance = bbox_targets.max(-1)[0]
+ inside_regress_range = (
+ (max_regress_distance >= regress_ranges[..., 0])
+ & (max_regress_distance <= regress_ranges[..., 1]))
+
+ # if there are still more than one objects for a location,
+ # we choose the one with minimal area
+ areas[inside_gt_bbox_mask == 0] = INF
+ areas[inside_regress_range == 0] = INF
+ min_area, min_area_inds = areas.min(dim=1)
+
+ labels = gt_labels[min_area_inds]
+ labels[min_area == INF] = self.num_classes # set as BG
+ bbox_targets = bbox_targets[range(num_points), min_area_inds]
+
+ return labels, bbox_targets
+
+ def centerness_target(self, pos_bbox_targets):
+ """Compute centerness targets.
+
+ Args:
+ pos_bbox_targets (Tensor): BBox targets of positive bboxes in shape
+ (num_pos, 4)
+
+ Returns:
+ Tensor: Centerness target.
+ """
+ # only calculate pos centerness targets, otherwise there may be nan
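+        # centerness = sqrt((min(l, r) / max(l, r)) * (min(t, b) / max(t, b)))
+        # where (l, t, r, b) are the distances from a positive point to the
+        # four sides of its assigned gt box; e.g. l=2, r=6, t=b=3 gives
+        # sqrt((2/6) * 1) ~= 0.577.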
+ left_right = pos_bbox_targets[:, [0, 2]]
+ top_bottom = pos_bbox_targets[:, [1, 3]]
+ centerness_targets = (
+ left_right.min(dim=-1)[0] / left_right.max(dim=-1)[0]) * (
+ top_bottom.min(dim=-1)[0] / top_bottom.max(dim=-1)[0])
+ return torch.sqrt(centerness_targets)
diff --git a/mmdet/models/dense_heads/fovea_head.py b/mmdet/models/dense_heads/fovea_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8ccea787cba3d092284d4a5e209adaf6521c86a
--- /dev/null
+++ b/mmdet/models/dense_heads/fovea_head.py
@@ -0,0 +1,341 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, normal_init
+from mmcv.ops import DeformConv2d
+
+from mmdet.core import multi_apply, multiclass_nms
+from ..builder import HEADS
+from .anchor_free_head import AnchorFreeHead
+
+INF = 1e8
+
+
+class FeatureAlign(nn.Module):
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ deform_groups=4):
+ super(FeatureAlign, self).__init__()
+ offset_channels = kernel_size * kernel_size * 2
+ self.conv_offset = nn.Conv2d(
+ 4, deform_groups * offset_channels, 1, bias=False)
+ self.conv_adaption = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ deform_groups=deform_groups)
+ self.relu = nn.ReLU(inplace=True)
+
+ def init_weights(self):
+ normal_init(self.conv_offset, std=0.1)
+ normal_init(self.conv_adaption, std=0.01)
+
+ def forward(self, x, shape):
+ offset = self.conv_offset(shape)
+ x = self.relu(self.conv_adaption(x, offset))
+ return x
+
+
+@HEADS.register_module()
+class FoveaHead(AnchorFreeHead):
+ """FoveaBox: Beyond Anchor-based Object Detector
+ https://arxiv.org/abs/1904.03797
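+
+    Example (a minimal sketch added for illustration; it assumes the default
+    ``AnchorFreeHead`` forward interface and default strides):
+        >>> self = FoveaHead(11, 7)
+        >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+        >>> cls_scores, bbox_preds = self.forward(feats)
+        >>> assert len(cls_scores) == len(self.strides)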
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ base_edge_list=(16, 32, 64, 128, 256),
+ scale_ranges=((8, 32), (16, 64), (32, 128), (64, 256), (128,
+ 512)),
+ sigma=0.4,
+ with_deform=False,
+ deform_groups=4,
+ **kwargs):
+ self.base_edge_list = base_edge_list
+ self.scale_ranges = scale_ranges
+ self.sigma = sigma
+ self.with_deform = with_deform
+ self.deform_groups = deform_groups
+ super().__init__(num_classes, in_channels, **kwargs)
+
+ def _init_layers(self):
+ # box branch
+ super()._init_reg_convs()
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+
+ # cls branch
+ if not self.with_deform:
+ super()._init_cls_convs()
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ else:
+ self.cls_convs = nn.ModuleList()
+ self.cls_convs.append(
+ ConvModule(
+ self.feat_channels, (self.feat_channels * 4),
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.norm_cfg is None))
+ self.cls_convs.append(
+ ConvModule((self.feat_channels * 4), (self.feat_channels * 4),
+ 1,
+ stride=1,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.norm_cfg is None))
+ self.feature_adaption = FeatureAlign(
+ self.feat_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.conv_cls = nn.Conv2d(
+ int(self.feat_channels * 4),
+ self.cls_out_channels,
+ 3,
+ padding=1)
+
+ def init_weights(self):
+ super().init_weights()
+ if self.with_deform:
+ self.feature_adaption.init_weights()
+
+ def forward_single(self, x):
+ cls_feat = x
+ reg_feat = x
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+ bbox_pred = self.conv_reg(reg_feat)
+ if self.with_deform:
+ cls_feat = self.feature_adaption(cls_feat, bbox_pred.exp())
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+ cls_score = self.conv_cls(cls_feat)
+ return cls_score, bbox_pred
+
+ def _get_points_single(self, *args, **kwargs):
+ y, x = super()._get_points_single(*args, **kwargs)
+ return y + 0.5, x + 0.5
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bbox_list,
+ gt_label_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ assert len(cls_scores) == len(bbox_preds)
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
+ bbox_preds[0].device)
+ num_imgs = cls_scores[0].size(0)
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(-1, self.cls_out_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_labels, flatten_bbox_targets = self.get_targets(
+ gt_bbox_list, gt_label_list, featmap_sizes, points)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_inds = ((flatten_labels >= 0)
+ & (flatten_labels < self.num_classes)).nonzero().view(-1)
+ num_pos = len(pos_inds)
+
+ loss_cls = self.loss_cls(
+ flatten_cls_scores, flatten_labels, avg_factor=num_pos + num_imgs)
+ if num_pos > 0:
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_weights = pos_bbox_targets.new_zeros(
+ pos_bbox_targets.size()) + 1.0
+ loss_bbox = self.loss_bbox(
+ pos_bbox_preds,
+ pos_bbox_targets,
+ pos_weights,
+ avg_factor=num_pos)
+ else:
+ loss_bbox = torch.tensor(
+ 0,
+ dtype=flatten_bbox_preds.dtype,
+ device=flatten_bbox_preds.device)
+ return dict(loss_cls=loss_cls, loss_bbox=loss_bbox)
+
+ def get_targets(self, gt_bbox_list, gt_label_list, featmap_sizes, points):
+ label_list, bbox_target_list = multi_apply(
+ self._get_target_single,
+ gt_bbox_list,
+ gt_label_list,
+ featmap_size_list=featmap_sizes,
+ point_list=points)
+ flatten_labels = [
+ torch.cat([
+ labels_level_img.flatten() for labels_level_img in labels_level
+ ]) for labels_level in zip(*label_list)
+ ]
+ flatten_bbox_targets = [
+ torch.cat([
+ bbox_targets_level_img.reshape(-1, 4)
+ for bbox_targets_level_img in bbox_targets_level
+ ]) for bbox_targets_level in zip(*bbox_target_list)
+ ]
+ flatten_labels = torch.cat(flatten_labels)
+ flatten_bbox_targets = torch.cat(flatten_bbox_targets)
+ return flatten_labels, flatten_bbox_targets
+
+ def _get_target_single(self,
+ gt_bboxes_raw,
+ gt_labels_raw,
+ featmap_size_list=None,
+ point_list=None):
+
+ gt_areas = torch.sqrt((gt_bboxes_raw[:, 2] - gt_bboxes_raw[:, 0]) *
+ (gt_bboxes_raw[:, 3] - gt_bboxes_raw[:, 1]))
+ label_list = []
+ bbox_target_list = []
+ # for each pyramid, find the cls and box target
+ for base_len, (lower_bound, upper_bound), stride, featmap_size, \
+ (y, x) in zip(self.base_edge_list, self.scale_ranges,
+ self.strides, featmap_size_list, point_list):
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ labels = gt_labels_raw.new_zeros(featmap_size) + self.num_classes
+ bbox_targets = gt_bboxes_raw.new(featmap_size[0], featmap_size[1],
+ 4) + 1
+ # scale assignment
+ hit_indices = ((gt_areas >= lower_bound) &
+ (gt_areas <= upper_bound)).nonzero().flatten()
+ if len(hit_indices) == 0:
+ label_list.append(labels)
+ bbox_target_list.append(torch.log(bbox_targets))
+ continue
+ _, hit_index_order = torch.sort(-gt_areas[hit_indices])
+ hit_indices = hit_indices[hit_index_order]
+ gt_bboxes = gt_bboxes_raw[hit_indices, :] / stride
+ gt_labels = gt_labels_raw[hit_indices]
+ half_w = 0.5 * (gt_bboxes[:, 2] - gt_bboxes[:, 0])
+ half_h = 0.5 * (gt_bboxes[:, 3] - gt_bboxes[:, 1])
+ # valid fovea area: left, right, top, down
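+            # the positive ("fovea") region is the gt box shrunk around its
+            # center by the factor sigma; only cells inside this region are
+            # assigned to the object below.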
+ pos_left = torch.ceil(
+ gt_bboxes[:, 0] + (1 - self.sigma) * half_w - 0.5).long().\
+ clamp(0, featmap_size[1] - 1)
+ pos_right = torch.floor(
+ gt_bboxes[:, 0] + (1 + self.sigma) * half_w - 0.5).long().\
+ clamp(0, featmap_size[1] - 1)
+ pos_top = torch.ceil(
+ gt_bboxes[:, 1] + (1 - self.sigma) * half_h - 0.5).long().\
+ clamp(0, featmap_size[0] - 1)
+ pos_down = torch.floor(
+ gt_bboxes[:, 1] + (1 + self.sigma) * half_h - 0.5).long().\
+ clamp(0, featmap_size[0] - 1)
+ for px1, py1, px2, py2, label, (gt_x1, gt_y1, gt_x2, gt_y2) in \
+ zip(pos_left, pos_top, pos_right, pos_down, gt_labels,
+ gt_bboxes_raw[hit_indices, :]):
+ labels[py1:py2 + 1, px1:px2 + 1] = label
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 0] = \
+ (stride * x[py1:py2 + 1, px1:px2 + 1] - gt_x1) / base_len
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 1] = \
+ (stride * y[py1:py2 + 1, px1:px2 + 1] - gt_y1) / base_len
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 2] = \
+ (gt_x2 - stride * x[py1:py2 + 1, px1:px2 + 1]) / base_len
+ bbox_targets[py1:py2 + 1, px1:px2 + 1, 3] = \
+ (gt_y2 - stride * y[py1:py2 + 1, px1:px2 + 1]) / base_len
+ bbox_targets = bbox_targets.clamp(min=1. / 16, max=16.)
+ label_list.append(labels)
+ bbox_target_list.append(torch.log(bbox_targets))
+ return label_list, bbox_target_list
+
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg=None,
+ rescale=None):
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ points = self.get_points(
+ featmap_sizes,
+ bbox_preds[0].dtype,
+ bbox_preds[0].device,
+ flatten=True)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ det_bboxes = self._get_bboxes_single(cls_score_list,
+ bbox_pred_list, featmap_sizes,
+ points, img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(det_bboxes)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ featmap_sizes,
+ point_list,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(point_list)
+ det_bboxes = []
+ det_scores = []
+ for cls_score, bbox_pred, featmap_size, stride, base_len, (y, x) \
+ in zip(cls_scores, bbox_preds, featmap_sizes, self.strides,
+ self.base_edge_list, point_list):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ scores = cls_score.permute(1, 2, 0).reshape(
+ -1, self.cls_out_channels).sigmoid()
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).exp()
+ nms_pre = cfg.get('nms_pre', -1)
+ if (nms_pre > 0) and (scores.shape[0] > nms_pre):
+ max_scores, _ = scores.max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ y = y[topk_inds]
+ x = x[topk_inds]
+ x1 = (stride * x - base_len * bbox_pred[:, 0]).\
+ clamp(min=0, max=img_shape[1] - 1)
+ y1 = (stride * y - base_len * bbox_pred[:, 1]).\
+ clamp(min=0, max=img_shape[0] - 1)
+ x2 = (stride * x + base_len * bbox_pred[:, 2]).\
+ clamp(min=0, max=img_shape[1] - 1)
+ y2 = (stride * y + base_len * bbox_pred[:, 3]).\
+ clamp(min=0, max=img_shape[0] - 1)
+ bboxes = torch.stack([x1, y1, x2, y2], -1)
+ det_bboxes.append(bboxes)
+ det_scores.append(scores)
+ det_bboxes = torch.cat(det_bboxes)
+ if rescale:
+ det_bboxes /= det_bboxes.new_tensor(scale_factor)
+ det_scores = torch.cat(det_scores)
+ padding = det_scores.new_zeros(det_scores.shape[0], 1)
+        # Note that FG labels are set to [0, num_classes - 1] since mmdet
+        # v2.0, and the BG cat_id is num_classes
+ det_scores = torch.cat([det_scores, padding], dim=1)
+ det_bboxes, det_labels = multiclass_nms(det_bboxes, det_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ return det_bboxes, det_labels
diff --git a/mmdet/models/dense_heads/free_anchor_retina_head.py b/mmdet/models/dense_heads/free_anchor_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..79879fdc3171b8e34b606b27eb1ceb67f4473e3e
--- /dev/null
+++ b/mmdet/models/dense_heads/free_anchor_retina_head.py
@@ -0,0 +1,270 @@
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import bbox_overlaps
+from ..builder import HEADS
+from .retina_head import RetinaHead
+
+EPS = 1e-12
+
+
+@HEADS.register_module()
+class FreeAnchorRetinaHead(RetinaHead):
+ """FreeAnchor RetinaHead used in https://arxiv.org/abs/1909.02466.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of conv layers in cls and reg tower.
+ Default: 4.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32,
+ requires_grad=True).
+        pre_anchor_topk (int): Number of boxes to be taken in each bag.
+        bbox_thr (float): The threshold of the saturated linear function.
+            It is usually the same as the IoU threshold used in NMS.
+ gamma (float): Gamma parameter in focal loss.
+ alpha (float): Alpha parameter in focal loss.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ pre_anchor_topk=50,
+ bbox_thr=0.6,
+ gamma=2.0,
+ alpha=0.5,
+ **kwargs):
+ super(FreeAnchorRetinaHead,
+ self).__init__(num_classes, in_channels, stacked_convs, conv_cfg,
+ norm_cfg, **kwargs)
+
+ self.pre_anchor_topk = pre_anchor_topk
+ self.bbox_thr = bbox_thr
+ self.gamma = gamma
+ self.alpha = alpha
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of one
+                image, in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == len(self.anchor_generator.base_anchors)
+
+ anchor_list, _ = self.get_anchors(featmap_sizes, img_metas)
+ anchors = [torch.cat(anchor) for anchor in anchor_list]
+
+ # concatenate each level
+ cls_scores = [
+ cls.permute(0, 2, 3,
+ 1).reshape(cls.size(0), -1, self.cls_out_channels)
+ for cls in cls_scores
+ ]
+ bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(bbox_pred.size(0), -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ cls_scores = torch.cat(cls_scores, dim=1)
+ bbox_preds = torch.cat(bbox_preds, dim=1)
+
+ cls_prob = torch.sigmoid(cls_scores)
+ box_prob = []
+ num_pos = 0
+ positive_losses = []
+ for _, (anchors_, gt_labels_, gt_bboxes_, cls_prob_,
+ bbox_preds_) in enumerate(
+ zip(anchors, gt_labels, gt_bboxes, cls_prob, bbox_preds)):
+
+ with torch.no_grad():
+ if len(gt_bboxes_) == 0:
+ image_box_prob = torch.zeros(
+ anchors_.size(0),
+ self.cls_out_channels).type_as(bbox_preds_)
+ else:
+ # box_localization: a_{j}^{loc}, shape: [j, 4]
+ pred_boxes = self.bbox_coder.decode(anchors_, bbox_preds_)
+
+ # object_box_iou: IoU_{ij}^{loc}, shape: [i, j]
+ object_box_iou = bbox_overlaps(gt_bboxes_, pred_boxes)
+
+ # object_box_prob: P{a_{j} -> b_{i}}, shape: [i, j]
+ t1 = self.bbox_thr
+ t2 = object_box_iou.max(
+ dim=1, keepdim=True).values.clamp(min=t1 + 1e-12)
+ object_box_prob = ((object_box_iou - t1) /
+ (t2 - t1)).clamp(
+ min=0, max=1)
+
+ # object_cls_box_prob: P{a_{j} -> b_{i}}, shape: [i, c, j]
+ num_obj = gt_labels_.size(0)
+ indices = torch.stack([
+ torch.arange(num_obj).type_as(gt_labels_), gt_labels_
+ ],
+ dim=0)
+ object_cls_box_prob = torch.sparse_coo_tensor(
+ indices, object_box_prob)
+
+ # image_box_iou: P{a_{j} \in A_{+}}, shape: [c, j]
+ """
+ from "start" to "end" implement:
+ image_box_iou = torch.sparse.max(object_cls_box_prob,
+ dim=0).t()
+
+ """
+ # start
+ box_cls_prob = torch.sparse.sum(
+ object_cls_box_prob, dim=0).to_dense()
+
+ indices = torch.nonzero(box_cls_prob, as_tuple=False).t_()
+ if indices.numel() == 0:
+ image_box_prob = torch.zeros(
+ anchors_.size(0),
+ self.cls_out_channels).type_as(object_box_prob)
+ else:
+ nonzero_box_prob = torch.where(
+ (gt_labels_.unsqueeze(dim=-1) == indices[0]),
+ object_box_prob[:, indices[1]],
+ torch.tensor([
+ 0
+ ]).type_as(object_box_prob)).max(dim=0).values
+
+ # upmap to shape [j, c]
+ image_box_prob = torch.sparse_coo_tensor(
+ indices.flip([0]),
+ nonzero_box_prob,
+ size=(anchors_.size(0),
+ self.cls_out_channels)).to_dense()
+ # end
+
+ box_prob.append(image_box_prob)
+
+ # construct bags for objects
+ match_quality_matrix = bbox_overlaps(gt_bboxes_, anchors_)
+ _, matched = torch.topk(
+ match_quality_matrix,
+ self.pre_anchor_topk,
+ dim=1,
+ sorted=False)
+ del match_quality_matrix
+
+ # matched_cls_prob: P_{ij}^{cls}
+ matched_cls_prob = torch.gather(
+ cls_prob_[matched], 2,
+ gt_labels_.view(-1, 1, 1).repeat(1, self.pre_anchor_topk,
+ 1)).squeeze(2)
+
+ # matched_box_prob: P_{ij}^{loc}
+ matched_anchors = anchors_[matched]
+ matched_object_targets = self.bbox_coder.encode(
+ matched_anchors,
+ gt_bboxes_.unsqueeze(dim=1).expand_as(matched_anchors))
+ loss_bbox = self.loss_bbox(
+ bbox_preds_[matched],
+ matched_object_targets,
+ reduction_override='none').sum(-1)
+ matched_box_prob = torch.exp(-loss_bbox)
+
+ # positive_losses: {-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )}
+ num_pos += len(gt_bboxes_)
+ positive_losses.append(
+ self.positive_bag_loss(matched_cls_prob, matched_box_prob))
+ positive_loss = torch.cat(positive_losses).sum() / max(1, num_pos)
+
+ # box_prob: P{a_{j} \in A_{+}}
+ box_prob = torch.stack(box_prob, dim=0)
+
+ # negative_loss:
+ # \sum_{j}{ FL((1 - P{a_{j} \in A_{+}}) * (1 - P_{j}^{bg})) } / n||B||
+ negative_loss = self.negative_bag_loss(cls_prob, box_prob).sum() / max(
+ 1, num_pos * self.pre_anchor_topk)
+
+ # avoid the absence of gradients in regression subnet
+ # when no ground-truth in a batch
+ if num_pos == 0:
+ positive_loss = bbox_preds.sum() * 0
+
+ losses = {
+ 'positive_bag_loss': positive_loss,
+ 'negative_bag_loss': negative_loss
+ }
+ return losses
+
+ def positive_bag_loss(self, matched_cls_prob, matched_box_prob):
+ """Compute positive bag loss.
+
+ :math:`-log( Mean-max(P_{ij}^{cls} * P_{ij}^{loc}) )`.
+
+ :math:`P_{ij}^{cls}`: matched_cls_prob, classification probability of matched samples.
+
+ :math:`P_{ij}^{loc}`: matched_box_prob, box probability of matched samples.
+
+ Args:
+            matched_cls_prob (Tensor): Classification probability of matched
+ samples in shape (num_gt, pre_anchor_topk).
+ matched_box_prob (Tensor): BBox probability of matched samples,
+ in shape (num_gt, pre_anchor_topk).
+
+ Returns:
+ Tensor: Positive bag loss in shape (num_gt,).
+ """ # noqa: E501, W605
+ # bag_prob = Mean-max(matched_prob)
+ matched_prob = matched_cls_prob * matched_box_prob
+ weight = 1 / torch.clamp(1 - matched_prob, 1e-12, None)
+ weight /= weight.sum(dim=1).unsqueeze(dim=-1)
+ bag_prob = (weight * matched_prob).sum(dim=1)
+ # positive_bag_loss = -self.alpha * log(bag_prob)
+ return self.alpha * F.binary_cross_entropy(
+ bag_prob, torch.ones_like(bag_prob), reduction='none')
+
+ def negative_bag_loss(self, cls_prob, box_prob):
+ """Compute negative bag loss.
+
+ :math:`FL((1 - P_{a_{j} \in A_{+}}) * (1 - P_{j}^{bg}))`.
+
+ :math:`P_{a_{j} \in A_{+}}`: Box_probability of matched samples.
+
+ :math:`P_{j}^{bg}`: Classification probability of negative samples.
+
+ Args:
+ cls_prob (Tensor): Classification probability, in shape
+ (num_img, num_anchors, num_classes).
+ box_prob (Tensor): Box probability, in shape
+ (num_img, num_anchors, num_classes).
+
+ Returns:
+ Tensor: Negative bag loss in shape (num_img, num_anchors, num_classes).
+ """ # noqa: E501, W605
+ prob = cls_prob * (1 - box_prob)
+ # There are some cases when neg_prob = 0.
+ # This will cause the neg_prob.log() to be inf without clamp.
+ prob = prob.clamp(min=EPS, max=1 - EPS)
+ negative_bag_loss = prob**self.gamma * F.binary_cross_entropy(
+ prob, torch.zeros_like(prob), reduction='none')
+ return (1 - self.alpha) * negative_bag_loss
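
As a quick sanity check, the mean-max weighting used by `positive_bag_loss` (and the focal-style negative term) can be reproduced on toy tensors. The snippet below is only a sketch of the reduction, with `alpha=0.5` taken from the head's default; the tensor values are made up:

```python
import torch
import torch.nn.functional as F

def mean_max(matched_prob, eps=1e-12):
    """Mean-max over a bag: the weights approach one-hot as the probs saturate."""
    weight = 1 / torch.clamp(1 - matched_prob, eps, None)
    weight = weight / weight.sum(dim=1, keepdim=True)
    return (weight * matched_prob).sum(dim=1)

alpha = 0.5
matched_cls_prob = torch.tensor([[0.9, 0.2, 0.1]])   # (num_gt, pre_anchor_topk)
matched_box_prob = torch.tensor([[0.8, 0.5, 0.3]])
bag_prob = mean_max(matched_cls_prob * matched_box_prob)
positive_bag_loss = alpha * F.binary_cross_entropy(
    bag_prob, torch.ones_like(bag_prob), reduction='none')  # -alpha * log(bag_prob)
```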
diff --git a/mmdet/models/dense_heads/fsaf_head.py b/mmdet/models/dense_heads/fsaf_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7183efce28596ba106411250f508aec5995fbf60
--- /dev/null
+++ b/mmdet/models/dense_heads/fsaf_head.py
@@ -0,0 +1,422 @@
+import numpy as np
+import torch
+from mmcv.cnn import normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, images_to_levels, multi_apply,
+ unmap)
+from ..builder import HEADS
+from ..losses.accuracy import accuracy
+from ..losses.utils import weight_reduce_loss
+from .retina_head import RetinaHead
+
+
+@HEADS.register_module()
+class FSAFHead(RetinaHead):
+ """Anchor-free head used in `FSAF `_.
+
+ The head contains two subnetworks. The first classifies anchor boxes and
+ the second regresses deltas for the anchors (num_anchors is 1 for anchor-
+ free methods)
+
+ Args:
+ *args: Same as its base class in :class:`RetinaHead`
+ score_threshold (float, optional): The score_threshold to calculate
+            positive recall. If given, predictions with scores lower than this
+            value are counted as incorrect. Defaults to None.
+ **kwargs: Same as its base class in :class:`RetinaHead`
+
+ Example:
+ >>> import torch
+ >>> self = FSAFHead(11, 7)
+ >>> x = torch.rand(1, 7, 32, 32)
+ >>> cls_score, bbox_pred = self.forward_single(x)
+ >>> # Each anchor predicts a score for each class except background
+ >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
+ >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
+ >>> assert cls_per_anchor == self.num_classes
+ >>> assert box_per_anchor == 4
+ """
+
+ def __init__(self, *args, score_threshold=None, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.score_threshold = score_threshold
+
+ def forward_single(self, x):
+ """Forward feature map of a single scale level.
+
+ Args:
+ x (Tensor): Feature map of a single scale level.
+
+ Returns:
+ tuple (Tensor):
+ cls_score (Tensor): Box scores for each scale level
+ Has shape (N, num_points * num_classes, H, W).
+ bbox_pred (Tensor): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W).
+ """
+ cls_score, bbox_pred = super().forward_single(x)
+ # relu: TBLR encoder only accepts positive bbox_pred
+ return cls_score, self.relu(bbox_pred)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ super(FSAFHead, self).init_weights()
+        # The positive bias in the self.retina_reg conv prevents the predicted
+        # bbox from having zero area
+ normal_init(self.retina_reg, std=0.01, bias=0.25)
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Most of the codes are the same with the base class
+ :obj: `AnchorHead`, except that it also collects and returns
+ the matched gt index in the image (from 0 to num_gt-1). If the
+ anchor bbox is not matched to any gt, the corresponding value in
+ pos_gt_inds is -1.
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # Assign gt and sample anchors
+ anchors = flat_anchors[inside_flags.type(torch.bool), :]
+ assign_result = self.assigner.assign(
+ anchors, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros((num_valid_anchors, label_channels),
+ dtype=torch.float)
+ pos_gt_inds = anchors.new_full((num_valid_anchors, ),
+ -1,
+ dtype=torch.long)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ if len(pos_inds) > 0:
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+ else:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, both
+ # the predicted boxes and regression targets should be with
+ # absolute coordinate format.
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ # The assigned gt_index for each anchor. (0-based)
+ pos_gt_inds[pos_inds] = sampling_result.pos_assigned_gt_inds
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # shadowed_labels is a tensor composed of tuples
+ # (anchor_inds, class_label) that indicate those anchors lying in the
+ # outer region of a gt or overlapped by another gt with a smaller
+ # area.
+ #
+ # Therefore, only the shadowed labels are ignored for loss calculation.
+ # the key `shadowed_labels` is defined in :obj:`CenterRegionAssigner`
+ shadowed_labels = assign_result.get_extra_property('shadowed_labels')
+ if shadowed_labels is not None and shadowed_labels.numel():
+ if len(shadowed_labels.shape) == 2:
+ idx_, label_ = shadowed_labels[:, 0], shadowed_labels[:, 1]
+ assert (labels[idx_] != label_).all(), \
+ 'One label cannot be both positive and ignored'
+ label_weights[idx_, label_] = 0
+ else:
+ label_weights[shadowed_labels] = 0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ labels = unmap(labels, num_total_anchors, inside_flags)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+ pos_gt_inds = unmap(
+ pos_gt_inds, num_total_anchors, inside_flags, fill=-1)
+
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds, sampling_result, pos_gt_inds)
+
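
The shadowed-label handling above only zeroes the weight of specific (anchor, class) pairs rather than whole anchors. A toy illustration of that indexing (shapes and values are made up):

```python
import torch

num_anchors, num_classes = 5, 3
label_weights = torch.ones(num_anchors, num_classes)

# (anchor_index, class_label) pairs reported as shadowed by the assigner
shadowed_labels = torch.tensor([[1, 2],
                                [4, 0]])
idx_, label_ = shadowed_labels[:, 0], shadowed_labels[:, 1]
label_weights[idx_, label_] = 0   # ignore only those (anchor, class) entries
```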
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_points * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_points * 4, H, W).
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes of one
+                image, in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ for i in range(len(bbox_preds)): # loop over fpn level
+ # avoid 0 area of the predicted bbox
+ bbox_preds[i] = bbox_preds[i].clamp(min=1e-4)
+ # TODO: It may directly use the base-class loss function.
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+ batch_size = len(gt_bboxes)
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg,
+ pos_assigned_gt_inds_list) = cls_reg_targets
+
+ num_gts = np.array(list(map(len, gt_labels)))
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+
+ # `pos_assigned_gt_inds_list` (length: fpn_levels) stores the assigned
+ # gt index of each anchor bbox in each fpn level.
+ cum_num_gts = list(np.cumsum(num_gts)) # length of batch_size
+ for i, assign in enumerate(pos_assigned_gt_inds_list):
+ # loop over fpn levels
+ for j in range(1, batch_size):
+ # loop over batch size
+ # Convert gt indices in each img to those in the batch
+ assign[j][assign[j] >= 0] += int(cum_num_gts[j - 1])
+ pos_assigned_gt_inds_list[i] = assign.flatten()
+ labels_list[i] = labels_list[i].flatten()
+ num_gts = sum(map(len, gt_labels)) # total number of gt in the batch
+ # The unique label index of each gt in the batch
+ label_sequence = torch.arange(num_gts, device=device)
+ # Collect the average loss of each gt in each level
+ with torch.no_grad():
+ loss_levels, = multi_apply(
+ self.collect_loss_level_single,
+ losses_cls,
+ losses_bbox,
+ pos_assigned_gt_inds_list,
+ labels_seq=label_sequence)
+ # Shape: (fpn_levels, num_gts). Loss of each gt at each fpn level
+ loss_levels = torch.stack(loss_levels, dim=0)
+ # Locate the best fpn level for loss back-propagation
+ if loss_levels.numel() == 0: # zero gt
+ argmin = loss_levels.new_empty((num_gts, ), dtype=torch.long)
+ else:
+ _, argmin = loss_levels.min(dim=0)
+
+ # Reweight the loss of each (anchor, label) pair, so that only those
+ # at the best gt level are back-propagated.
+ losses_cls, losses_bbox, pos_inds = multi_apply(
+ self.reweight_loss_single,
+ losses_cls,
+ losses_bbox,
+ pos_assigned_gt_inds_list,
+ labels_list,
+ list(range(len(losses_cls))),
+ min_levels=argmin)
+ num_pos = torch.cat(pos_inds, 0).sum().float()
+ pos_recall = self.calculate_pos_recall(cls_scores, labels_list,
+ pos_inds)
+
+ if num_pos == 0: # No gt
+ avg_factor = num_pos + float(num_total_neg)
+ else:
+ avg_factor = num_pos
+ for i in range(len(losses_cls)):
+ losses_cls[i] /= avg_factor
+ losses_bbox[i] /= avg_factor
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ num_pos=num_pos / batch_size,
+ pos_recall=pos_recall)
+
+ def calculate_pos_recall(self, cls_scores, labels_list, pos_inds):
+ """Calculate positive recall with score threshold.
+
+ Args:
+ cls_scores (list[Tensor]): Classification scores at all fpn levels.
+ Each tensor is in shape (N, num_classes * num_anchors, H, W)
+ labels_list (list[Tensor]): The label that each anchor is assigned
+ to. Shape (N * H * W * num_anchors, )
+ pos_inds (list[Tensor]): List of bool tensors indicating whether
+ the anchor is assigned to a positive label.
+ Shape (N * H * W * num_anchors, )
+
+ Returns:
+ Tensor: A single float number indicating the positive recall.
+ """
+ with torch.no_grad():
+ num_class = self.num_classes
+ scores = [
+ cls.permute(0, 2, 3, 1).reshape(-1, num_class)[pos]
+ for cls, pos in zip(cls_scores, pos_inds)
+ ]
+ labels = [
+ label.reshape(-1)[pos]
+ for label, pos in zip(labels_list, pos_inds)
+ ]
+ scores = torch.cat(scores, dim=0)
+ labels = torch.cat(labels, dim=0)
+ if self.use_sigmoid_cls:
+ scores = scores.sigmoid()
+ else:
+ scores = scores.softmax(dim=1)
+
+ return accuracy(scores, labels, thresh=self.score_threshold)
+
+ def collect_loss_level_single(self, cls_loss, reg_loss, assigned_gt_inds,
+ labels_seq):
+ """Get the average loss in each FPN level w.r.t. each gt label.
+
+ Args:
+ cls_loss (Tensor): Classification loss of each feature map pixel,
+ shape (num_anchor, num_class)
+ reg_loss (Tensor): Regression loss of each feature map pixel,
+ shape (num_anchor, 4)
+ assigned_gt_inds (Tensor): It indicates which gt the prior is
+                assigned to (0-based; -1 means no assignment). Shape (num_anchor,).
+            labels_seq (Tensor): Unique gt indices in the batch. Shape (num_gt,).
+
+ Returns:
+ shape: (num_gt), average loss of each gt in this level
+ """
+ if len(reg_loss.shape) == 2: # iou loss has shape (num_prior, 4)
+ reg_loss = reg_loss.sum(dim=-1) # sum loss in tblr dims
+ if len(cls_loss.shape) == 2:
+ cls_loss = cls_loss.sum(dim=-1) # sum loss in class dims
+ loss = cls_loss + reg_loss
+ assert loss.size(0) == assigned_gt_inds.size(0)
+ # Default loss value is 1e6 for a layer where no anchor is positive
+ # to ensure it will not be chosen to back-propagate gradient
+ losses_ = loss.new_full(labels_seq.shape, 1e6)
+ for i, l in enumerate(labels_seq):
+ match = assigned_gt_inds == l
+ if match.any():
+ losses_[i] = loss[match].mean()
+ return losses_,
+
+ def reweight_loss_single(self, cls_loss, reg_loss, assigned_gt_inds,
+ labels, level, min_levels):
+ """Reweight loss values at each level.
+
+ Reassign loss values at each level by masking those where the
+ pre-calculated loss is too large. Then return the reduced losses.
+
+ Args:
+ cls_loss (Tensor): Element-wise classification loss.
+ Shape: (num_anchors, num_classes)
+ reg_loss (Tensor): Element-wise regression loss.
+ Shape: (num_anchors, 4)
+ assigned_gt_inds (Tensor): The gt indices that each anchor bbox
+ is assigned to. -1 denotes a negative anchor, otherwise it is the
+ gt index (0-based). Shape: (num_anchors, ),
+ labels (Tensor): Label assigned to anchors. Shape: (num_anchors, ).
+ level (int): The current level index in the pyramid
+ (0-4 for RetinaNet)
+ min_levels (Tensor): The best-matching level for each gt.
+ Shape: (num_gts, ),
+
+ Returns:
+ tuple:
+ - cls_loss: Reduced corrected classification loss. Scalar.
+ - reg_loss: Reduced corrected regression loss. Scalar.
+ - pos_flags (Tensor): Corrected bool tensor indicating the
+ final positive anchors. Shape: (num_anchors, ).
+ """
+ loc_weight = torch.ones_like(reg_loss)
+ cls_weight = torch.ones_like(cls_loss)
+ pos_flags = assigned_gt_inds >= 0 # positive pixel flag
+ pos_indices = torch.nonzero(pos_flags, as_tuple=False).flatten()
+
+ if pos_flags.any(): # pos pixels exist
+ pos_assigned_gt_inds = assigned_gt_inds[pos_flags]
+ zeroing_indices = (min_levels[pos_assigned_gt_inds] != level)
+ neg_indices = pos_indices[zeroing_indices]
+
+ if neg_indices.numel():
+ pos_flags[neg_indices] = 0
+ loc_weight[neg_indices] = 0
+ # Only the weight corresponding to the label is
+ # zeroed out if not selected
+ zeroing_labels = labels[neg_indices]
+ assert (zeroing_labels >= 0).all()
+ cls_weight[neg_indices, zeroing_labels] = 0
+
+ # Weighted loss for both cls and reg loss
+ cls_loss = weight_reduce_loss(cls_loss, cls_weight, reduction='sum')
+ reg_loss = weight_reduce_loss(reg_loss, loc_weight, reduction='sum')
+
+ return cls_loss, reg_loss, pos_flags
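
The level-selection logic in `loss` / `collect_loss_level_single` / `reweight_loss_single` reduces to: compute each gt's average loss at every FPN level, take the argmin level per gt, and keep only that gt's anchors on its best level. A compact sketch of the selection step on toy values (independent of the head itself):

```python
import torch

# loss_levels[l, g]: average loss of gt g at fpn level l (1e6 if unmatched there)
loss_levels = torch.tensor([[0.7, 1e6],
                            [0.3, 0.9],
                            [0.5, 0.4]])          # 3 levels, 2 gts
_, argmin = loss_levels.min(dim=0)                # best level per gt -> [1, 2]

# at a given level, an anchor assigned to gt g keeps its loss iff argmin[g] == level
level = 1
assigned_gt_inds = torch.tensor([0, 1, -1])       # -1 marks a negative anchor
pos = assigned_gt_inds >= 0
keep = torch.zeros_like(pos)
keep[pos] = argmin[assigned_gt_inds[pos]] == level   # -> [True, False, False]
```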
diff --git a/mmdet/models/dense_heads/ga_retina_head.py b/mmdet/models/dense_heads/ga_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8822d1ca78ee2fa2f304a0649e81274830383533
--- /dev/null
+++ b/mmdet/models/dense_heads/ga_retina_head.py
@@ -0,0 +1,109 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.ops import MaskedConv2d
+
+from ..builder import HEADS
+from .guided_anchor_head import FeatureAdaption, GuidedAnchorHead
+
+
+@HEADS.register_module()
+class GARetinaHead(GuidedAnchorHead):
+ """Guided-Anchor-based RetinaNet head."""
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(GARetinaHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ self.conv_loc = nn.Conv2d(self.feat_channels, 1, 1)
+ self.conv_shape = nn.Conv2d(self.feat_channels, self.num_anchors * 2,
+ 1)
+ self.feature_adaption_cls = FeatureAdaption(
+ self.feat_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.feature_adaption_reg = FeatureAdaption(
+ self.feat_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.retina_cls = MaskedConv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = MaskedConv2d(
+ self.feat_channels, self.num_anchors * 4, 3, padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the layer."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+
+ self.feature_adaption_cls.init_weights()
+ self.feature_adaption_reg.init_weights()
+
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.conv_loc, std=0.01, bias=bias_cls)
+ normal_init(self.conv_shape, std=0.01)
+ normal_init(self.retina_cls, std=0.01, bias=bias_cls)
+ normal_init(self.retina_reg, std=0.01)
+
+ def forward_single(self, x):
+ """Forward feature map of a single scale level."""
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+
+ loc_pred = self.conv_loc(cls_feat)
+ shape_pred = self.conv_shape(reg_feat)
+
+ cls_feat = self.feature_adaption_cls(cls_feat, shape_pred)
+ reg_feat = self.feature_adaption_reg(reg_feat, shape_pred)
+
+ if not self.training:
+ mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
+ else:
+ mask = None
+ cls_score = self.retina_cls(cls_feat, mask)
+ bbox_pred = self.retina_reg(reg_feat, mask)
+ return cls_score, bbox_pred, shape_pred, loc_pred
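
At test time the location branch acts as a hard gate: only positions whose predicted objectness exceeds `loc_filter_thr` are evaluated by the masked convolutions. A small sketch of that thresholding on plain tensors (`MaskedConv2d` itself comes from mmcv and is not reimplemented here):

```python
import torch

loc_pred = torch.randn(1, 1, 4, 4)                 # (N, 1, H, W) location logits
loc_filter_thr = 0.01
mask = loc_pred.sigmoid()[0] >= loc_filter_thr     # (1, H, W) boolean keep-mask

# positions where the mask is False are skipped by the masked conv heads,
# so roughly this fraction of locations is actually evaluated:
kept_ratio = mask.float().mean().item()
```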
diff --git a/mmdet/models/dense_heads/ga_rpn_head.py b/mmdet/models/dense_heads/ga_rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ec0d4fdd3475bfbd2e541a6e8130b1df9ad861a
--- /dev/null
+++ b/mmdet/models/dense_heads/ga_rpn_head.py
@@ -0,0 +1,171 @@
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv import ConfigDict
+from mmcv.cnn import normal_init
+from mmcv.ops import nms
+
+from ..builder import HEADS
+from .guided_anchor_head import GuidedAnchorHead
+from .rpn_test_mixin import RPNTestMixin
+
+
+@HEADS.register_module()
+class GARPNHead(RPNTestMixin, GuidedAnchorHead):
+ """Guided-Anchor-based RPN head."""
+
+ def __init__(self, in_channels, **kwargs):
+ super(GARPNHead, self).__init__(1, in_channels, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.rpn_conv = nn.Conv2d(
+ self.in_channels, self.feat_channels, 3, padding=1)
+ super(GARPNHead, self)._init_layers()
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ normal_init(self.rpn_conv, std=0.01)
+ super(GARPNHead, self).init_weights()
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level."""
+
+ x = self.rpn_conv(x)
+ x = F.relu(x, inplace=True)
+ (cls_score, bbox_pred, shape_pred,
+ loc_pred) = super(GARPNHead, self).forward_single(x)
+ return cls_score, bbox_pred, shape_pred, loc_pred
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore=None):
+ losses = super(GARPNHead, self).loss(
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ gt_bboxes,
+ None,
+ img_metas,
+ gt_bboxes_ignore=gt_bboxes_ignore)
+ return dict(
+ loss_rpn_cls=losses['loss_cls'],
+ loss_rpn_bbox=losses['loss_bbox'],
+ loss_anchor_shape=losses['loss_shape'],
+ loss_anchor_loc=losses['loss_loc'])
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ mlvl_masks,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+
+ cfg = copy.deepcopy(cfg)
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+ warnings.warn(
+ 'In rpn_proposal or test_cfg, '
+ 'nms_thr has been moved to a dict named nms as '
+                'iou_threshold, and max_num has been renamed max_per_img; '
+                'the original argument names and the old way of specifying '
+                'the NMS iou_threshold are deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+ assert cfg.max_num == cfg.max_per_img, f'You ' \
+ f'set max_num and max_per_img at the same time, ' \
+ f'but get {cfg.max_num} ' \
+                    f'and {cfg.max_per_img} respectively. ' \
+ 'Please delete max_num which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set ' \
+ f'iou_threshold in nms and ' \
+ f'nms_thr at the same time, but get ' \
+ f'{cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete the ' \
+ f'nms_thr which will be deprecated.'
+
+        assert cfg.nms.get('type', 'nms') == 'nms', 'GARPNHead only supports ' \
+ 'naive nms.'
+
+ mlvl_proposals = []
+ for idx in range(len(cls_scores)):
+ rpn_cls_score = cls_scores[idx]
+ rpn_bbox_pred = bbox_preds[idx]
+ anchors = mlvl_anchors[idx]
+ mask = mlvl_masks[idx]
+ assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+ # if no location is kept, end.
+ if mask.sum() == 0:
+ continue
+ rpn_cls_score = rpn_cls_score.permute(1, 2, 0)
+ if self.use_sigmoid_cls:
+ rpn_cls_score = rpn_cls_score.reshape(-1)
+ scores = rpn_cls_score.sigmoid()
+ else:
+ rpn_cls_score = rpn_cls_score.reshape(-1, 2)
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ scores = rpn_cls_score.softmax(dim=1)[:, :-1]
+ # filter scores, bbox_pred w.r.t. mask.
+ # anchors are filtered in get_anchors() beforehand.
+ scores = scores[mask]
+ rpn_bbox_pred = rpn_bbox_pred.permute(1, 2, 0).reshape(-1,
+ 4)[mask, :]
+ if scores.dim() == 0:
+ rpn_bbox_pred = rpn_bbox_pred.unsqueeze(0)
+ anchors = anchors.unsqueeze(0)
+ scores = scores.unsqueeze(0)
+ # filter anchors, bbox_pred, scores w.r.t. scores
+ if cfg.nms_pre > 0 and scores.shape[0] > cfg.nms_pre:
+ _, topk_inds = scores.topk(cfg.nms_pre)
+ rpn_bbox_pred = rpn_bbox_pred[topk_inds, :]
+ anchors = anchors[topk_inds, :]
+ scores = scores[topk_inds]
+ # get proposals w.r.t. anchors and rpn_bbox_pred
+ proposals = self.bbox_coder.decode(
+ anchors, rpn_bbox_pred, max_shape=img_shape)
+ # filter out too small bboxes
+ if cfg.min_bbox_size > 0:
+ w = proposals[:, 2] - proposals[:, 0]
+ h = proposals[:, 3] - proposals[:, 1]
+ valid_inds = torch.nonzero(
+ (w >= cfg.min_bbox_size) & (h >= cfg.min_bbox_size),
+ as_tuple=False).squeeze()
+ proposals = proposals[valid_inds, :]
+ scores = scores[valid_inds]
+ # NMS in current level
+ proposals, _ = nms(proposals, scores, cfg.nms.iou_threshold)
+ proposals = proposals[:cfg.nms_post, :]
+ mlvl_proposals.append(proposals)
+ proposals = torch.cat(mlvl_proposals, 0)
+ if cfg.get('nms_across_levels', False):
+ # NMS across multi levels
+ proposals, _ = nms(proposals[:, :4], proposals[:, -1],
+ cfg.nms.iou_threshold)
+ proposals = proposals[:cfg.max_per_img, :]
+ else:
+ scores = proposals[:, 4]
+ num = min(cfg.max_per_img, proposals.shape[0])
+ _, topk_inds = scores.topk(num)
+ proposals = proposals[topk_inds, :]
+ return proposals
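
The deprecation handling at the top of `_get_bboxes_single` converts old-style RPN test settings (`nms_thr`, `max_num`) into the dict-based form that the rest of the method reads. Below is a hedged example of the non-deprecated layout it expects; the numeric values are placeholders, not recommended settings:

```python
# new-style rpn_proposal / test_cfg fragment read by GARPNHead._get_bboxes_single
rpn_test_cfg = dict(
    nms_pre=2000,                                # top-k scores kept before NMS
    nms=dict(type='nms', iou_threshold=0.7),     # replaces the old nms_thr
    nms_post=1000,                               # proposals kept per level after NMS
    max_per_img=1000,                            # replaces the old max_num
    min_bbox_size=0)

# deprecated equivalents that trigger the warning and get converted:
#   nms_thr=0.7  -> nms=dict(type='nms', iou_threshold=0.7)
#   max_num=1000 -> max_per_img=1000
```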
diff --git a/mmdet/models/dense_heads/gfl_head.py b/mmdet/models/dense_heads/gfl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..961bc92237663ad5343d3d08eb9c0e4e811ada05
--- /dev/null
+++ b/mmdet/models/dense_heads/gfl_head.py
@@ -0,0 +1,647 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, bbox2distance, bbox_overlaps,
+ build_assigner, build_sampler, distance2bbox,
+ images_to_levels, multi_apply, multiclass_nms,
+ reduce_mean, unmap)
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+class Integral(nn.Module):
+ """A fixed layer for calculating integral result from distribution.
+
+ This layer calculates the target location by :math: `sum{P(y_i) * y_i}`,
+    P(y_i) denotes the softmax vector that represents the discrete distribution,
+    and y_i denotes the discrete set, usually {0, 1, 2, ..., reg_max}.
+
+ Args:
+ reg_max (int): The maximal value of the discrete set. Default: 16. You
+ may want to reset it according to your new dataset or related
+ settings.
+ """
+
+ def __init__(self, reg_max=16):
+ super(Integral, self).__init__()
+ self.reg_max = reg_max
+ self.register_buffer('project',
+ torch.linspace(0, self.reg_max, self.reg_max + 1))
+
+ def forward(self, x):
+ """Forward feature from the regression head to get integral result of
+ bounding box location.
+
+ Args:
+ x (Tensor): Features of the regression head, shape (N, 4*(n+1)),
+ n is self.reg_max.
+
+ Returns:
+ x (Tensor): Integral result of box locations, i.e., distance
+ offsets from the box center in four directions, shape (N, 4).
+ """
+ x = F.softmax(x.reshape(-1, self.reg_max + 1), dim=1)
+ x = F.linear(x, self.project.type_as(x)).reshape(-1, 4)
+ return x
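
Numerically, `Integral` is just the expectation of a softmax distribution over the discrete set {0, ..., reg_max}, computed independently for the four box sides. A toy check of that expectation (the logit values are made up):

```python
import torch
import torch.nn.functional as F

reg_max = 16
project = torch.linspace(0, reg_max, reg_max + 1)           # [0, 1, ..., 16]

logits = torch.zeros(1, 4 * (reg_max + 1))                  # one location, 4 sides
logits[0, 3] = 10.0                                         # left side: mass on bin 3

probs = F.softmax(logits.reshape(-1, reg_max + 1), dim=1)   # (4, 17)
offsets = F.linear(probs, project).reshape(-1, 4)           # expected distance per side
# offsets[0, 0] is ~3; the other (uniform) sides average to reg_max / 2 = 8
```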
+
+
+@HEADS.register_module()
+class GFLHead(AnchorHead):
+ """Generalized Focal Loss: Learning Qualified and Distributed Bounding
+ Boxes for Dense Object Detection.
+
+    The GFL head structure is similar to that of ATSS; however, GFL uses
+ 1) joint representation for classification and localization quality, and
+ 2) flexible General distribution for bounding box locations,
+ which are supervised by
+ Quality Focal Loss (QFL) and Distribution Focal Loss (DFL), respectively
+
+ https://arxiv.org/abs/2006.04388
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of conv layers in cls and reg tower.
+ Default: 4.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='GN', num_groups=32, requires_grad=True).
+ loss_qfl (dict): Config of Quality Focal Loss (QFL).
+ reg_max (int): Max value of integral set :math: `{0, ..., reg_max}`
+ in QFL setting. Default: 16.
+ Example:
+ >>> self = GFLHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+ >>> cls_quality_score, bbox_pred = self.forward(feats)
+ >>> assert len(cls_quality_score) == len(self.scales)
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ loss_dfl=dict(type='DistributionFocalLoss', loss_weight=0.25),
+ reg_max=16,
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.reg_max = reg_max
+ super(GFLHead, self).__init__(num_classes, in_channels, **kwargs)
+
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+            # GFL does not perform sampling, so use a PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.integral = Integral(self.reg_max)
+ self.loss_dfl = build_loss(loss_dfl)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ assert self.num_anchors == 1, 'anchor free version'
+ self.gfl_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.gfl_reg = nn.Conv2d(
+ self.feat_channels, 4 * (self.reg_max + 1), 3, padding=1)
+ self.scales = nn.ModuleList(
+ [Scale(1.0) for _ in self.anchor_generator.strides])
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.gfl_cls, std=0.01, bias=bias_cls)
+ normal_init(self.gfl_reg, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple: Usually a tuple of classification scores and bbox prediction
+ cls_scores (list[Tensor]): Classification and quality (IoU)
+ joint scores for all scale levels, each is a 4D-tensor,
+ the channel number is num_classes.
+ bbox_preds (list[Tensor]): Box distribution logits for all
+ scale levels, each is a 4D-tensor, the channel number is
+ 4*(n+1), n is max value of integral set.
+ """
+ return multi_apply(self.forward_single, feats, self.scales)
+
+ def forward_single(self, x, scale):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+ scale (:obj: `mmcv.cnn.Scale`): Learnable scale module to resize
+ the bbox prediction.
+
+ Returns:
+ tuple:
+ cls_score (Tensor): Cls and quality joint scores for a single
+ scale level the channel number is num_classes.
+ bbox_pred (Tensor): Box distribution logits for a single scale
+ level, the channel number is 4*(n+1), n is max value of
+ integral set.
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.gfl_cls(cls_feat)
+ bbox_pred = scale(self.gfl_reg(reg_feat)).float()
+ return cls_score, bbox_pred
+
+ def anchor_center(self, anchors):
+ """Get anchor centers from anchors.
+
+ Args:
+ anchors (Tensor): Anchor list with shape (N, 4), "xyxy" format.
+
+ Returns:
+ Tensor: Anchor centers with shape (N, 2), "xy" format.
+ """
+ anchors_cx = (anchors[..., 2] + anchors[..., 0]) / 2
+ anchors_cy = (anchors[..., 3] + anchors[..., 1]) / 2
+ return torch.stack([anchors_cx, anchors_cy], dim=-1)
+
+ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
+ bbox_targets, stride, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ cls_score (Tensor): Cls and quality joint scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_pred (Tensor): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+ shape (N, num_total_anchors, 4).
+ stride (tuple): Stride in this scale level.
+ num_total_samples (int): Number of positive samples that is
+ reduced over all GPUs.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert stride[0] == stride[1], 'h stride is not equal to w stride!'
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(-1, 4 * (self.reg_max + 1))
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
+ score = label_weights.new_zeros(labels.shape)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+ pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
+
+ weight_targets = cls_score.detach().sigmoid()
+ weight_targets = weight_targets.max(dim=1)[0][pos_inds]
+ pos_bbox_pred_corners = self.integral(pos_bbox_pred)
+ pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
+ pos_bbox_pred_corners)
+ pos_decode_bbox_targets = pos_bbox_targets / stride[0]
+ score[pos_inds] = bbox_overlaps(
+ pos_decode_bbox_pred.detach(),
+ pos_decode_bbox_targets,
+ is_aligned=True)
+ pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
+ target_corners = bbox2distance(pos_anchor_centers,
+ pos_decode_bbox_targets,
+ self.reg_max).reshape(-1)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ weight=weight_targets,
+ avg_factor=1.0)
+
+ # dfl loss
+ loss_dfl = self.loss_dfl(
+ pred_corners,
+ target_corners,
+ weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+ avg_factor=4.0)
+ else:
+ loss_bbox = bbox_pred.sum() * 0
+ loss_dfl = bbox_pred.sum() * 0
+ weight_targets = bbox_pred.new_tensor(0)
+
+ # cls (qfl) loss
+ loss_cls = self.loss_cls(
+ cls_score, (labels, score),
+ weight=label_weights,
+ avg_factor=num_total_samples)
+
+ return loss_cls, loss_bbox, loss_dfl, weight_targets.sum()
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Cls and quality scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_preds (list[Tensor]): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ losses_cls, losses_bbox, losses_dfl,\
+ avg_factor = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ self.anchor_generator.strides,
+ num_total_samples=num_total_samples)
+
+ avg_factor = sum(avg_factor)
+ avg_factor = reduce_mean(avg_factor).item()
+ losses_bbox = list(map(lambda x: x / avg_factor, losses_bbox))
+ losses_dfl = list(map(lambda x: x / avg_factor, losses_dfl))
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_dfl=losses_dfl)
+
+ def _get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ img_shapes,
+ scale_factors,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into labeled boxes.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for a single scale level
+ has shape (N, num_classes, H, W).
+ bbox_preds (list[Tensor]): Box distribution logits for a single
+ scale level with shape (N, 4*(n+1), H, W), n is max value of
+ integral set.
+ mlvl_anchors (list[Tensor]): Box reference for a single scale level
+ with shape (num_total_anchors, 4).
+ img_shapes (list[tuple[int]]): Shape of the input image,
+ list[(height, width, 3)].
+            scale_factors (list[ndarray]): Scale factor of the image, arranged as
+ (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 5) tensor, where 5 represent
+ (tl_x, tl_y, br_x, br_y, score) and the score between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+ batch_size = cls_scores[0].shape[0]
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for cls_score, bbox_pred, stride, anchors in zip(
+ cls_scores, bbox_preds, self.anchor_generator.strides,
+ mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ assert stride[0] == stride[1]
+ scores = cls_score.permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.cls_out_channels).sigmoid()
+ bbox_pred = bbox_pred.permute(0, 2, 3, 1)
+
+ bbox_pred = self.integral(bbox_pred) * stride[0]
+ bbox_pred = bbox_pred.reshape(batch_size, -1, 4)
+
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[1] > nms_pre:
+ max_scores, _ = scores.max(-1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+ else:
+ anchors = anchors.expand_as(bbox_pred)
+
+ bboxes = distance2bbox(
+ self.anchor_center(anchors), bbox_pred, max_shape=img_shapes)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+
+ batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+ if rescale:
+ batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = batch_mlvl_scores.new_zeros(batch_size,
+ batch_mlvl_scores.shape[1], 1)
+ batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
+
+ if with_nms:
+ det_results = []
+ for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes,
+ batch_mlvl_scores):
+ det_bbox, det_label = multiclass_nms(mlvl_bboxes, mlvl_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ det_results.append(tuple([det_bbox, det_label]))
+ else:
+ det_results = [
+ tuple(mlvl_bs)
+ for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores)
+ ]
+ return det_results
+
+ def get_targets(self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True):
+ """Get targets for GFL head.
+
+ This method is almost the same as `AnchorHead.get_targets()`. Besides
+ returning the targets as the parent method does, it also returns the
+ anchors as the first element of the returned tuple.
+ """
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ num_level_anchors_list = [num_level_anchors] * num_imgs
+
+ # concat all level anchors and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ anchor_list[i] = torch.cat(anchor_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_anchors, all_labels, all_label_weights, all_bbox_targets,
+ all_bbox_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ anchor_list,
+ valid_flag_list,
+ num_level_anchors_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ anchors_list = images_to_levels(all_anchors, num_level_anchors)
+ labels_list = images_to_levels(all_labels, num_level_anchors)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_anchors)
+ bbox_targets_list = images_to_levels(all_bbox_targets,
+ num_level_anchors)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_anchors)
+ return (anchors_list, labels_list, label_weights_list,
+ bbox_targets_list, bbox_weights_list, num_total_pos,
+ num_total_neg)
+
+ def _get_target_single(self,
+ flat_anchors,
+ valid_flags,
+ num_level_anchors,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression, classification targets for anchors in a single
+ image.
+
+ Args:
+ flat_anchors (Tensor): Multi-level anchors of the image, which are
+ concatenated into a single tensor of shape (num_anchors, 4)
+ valid_flags (Tensor): Multi level valid flags of the image,
+ which are concatenated into a single tensor of
+ shape (num_anchors,).
+            num_level_anchors (list[int]): Number of anchors of each scale level.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: N is the number of total anchors in the image.
+ anchors (Tensor): All anchors in the image with shape (N, 4).
+ labels (Tensor): Labels of all anchors in the image with shape
+ (N,).
+ label_weights (Tensor): Label weights of all anchor in the
+ image with shape (N,).
+ bbox_targets (Tensor): BBox targets of all anchors in the
+ image with shape (N, 4).
+ bbox_weights (Tensor): BBox weights of all anchors in the
+ image with shape (N, 4).
+ pos_inds (Tensor): Indices of positive anchor with shape
+ (num_pos,).
+ neg_inds (Tensor): Indices of negative anchor with shape
+ (num_neg,).
+ """
+ inside_flags = anchor_inside_flags(flat_anchors, valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample anchors
+ anchors = flat_anchors[inside_flags, :]
+
+ num_level_anchors_inside = self.get_num_level_anchors_inside(
+ num_level_anchors, inside_flags)
+ assign_result = self.assigner.assign(anchors, num_level_anchors_inside,
+ gt_bboxes, gt_bboxes_ignore,
+ gt_labels)
+
+ sampling_result = self.sampler.sample(assign_result, anchors,
+ gt_bboxes)
+
+ num_valid_anchors = anchors.shape[0]
+ bbox_targets = torch.zeros_like(anchors)
+ bbox_weights = torch.zeros_like(anchors)
+ labels = anchors.new_full((num_valid_anchors, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = anchors.new_zeros(num_valid_anchors, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ pos_bbox_targets = sampling_result.pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_anchors.size(0)
+ anchors = unmap(anchors, num_total_anchors, inside_flags)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_targets = unmap(bbox_targets, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (anchors, labels, label_weights, bbox_targets, bbox_weights,
+ pos_inds, neg_inds)
+
+ def get_num_level_anchors_inside(self, num_level_anchors, inside_flags):
+ split_inside_flags = torch.split(inside_flags, num_level_anchors)
+ num_level_anchors_inside = [
+ int(flags.sum()) for flags in split_inside_flags
+ ]
+ return num_level_anchors_inside
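
Putting the inference path of `_get_bboxes` together: the distribution logits are integrated into per-side distances, scaled by the level stride, and decoded from the anchor centers. The sketch below mirrors that decoding with a local stand-in for mmdet's `distance2bbox`; names and numbers are illustrative only:

```python
import torch

def decode_from_centers(centers, distances, max_shape=None):
    """xyxy boxes from (cx, cy) centers and (left, top, right, bottom) distances."""
    x1 = centers[:, 0] - distances[:, 0]
    y1 = centers[:, 1] - distances[:, 1]
    x2 = centers[:, 0] + distances[:, 2]
    y2 = centers[:, 1] + distances[:, 3]
    if max_shape is not None:
        x1, x2 = x1.clamp(0, max_shape[1]), x2.clamp(0, max_shape[1])
        y1, y2 = y1.clamp(0, max_shape[0]), y2.clamp(0, max_shape[0])
    return torch.stack([x1, y1, x2, y2], dim=-1)

stride = 8
anchors = torch.tensor([[16., 16., 24., 24.]])               # square anchors, xyxy
centers = torch.stack([(anchors[:, 0] + anchors[:, 2]) / 2,
                       (anchors[:, 1] + anchors[:, 3]) / 2], dim=-1)
distances = torch.tensor([[2., 2., 3., 3.]]) * stride        # Integral output * stride
boxes = decode_from_centers(centers, distances, max_shape=(800, 1333))
# boxes -> tensor([[ 4.,  4., 44., 44.]])
```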
diff --git a/mmdet/models/dense_heads/guided_anchor_head.py b/mmdet/models/dense_heads/guided_anchor_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..997ebb751ade2ebae3fce335a08c46f596c60913
--- /dev/null
+++ b/mmdet/models/dense_heads/guided_anchor_head.py
@@ -0,0 +1,860 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import bias_init_with_prob, normal_init
+from mmcv.ops import DeformConv2d, MaskedConv2d
+from mmcv.runner import force_fp32
+
+from mmdet.core import (anchor_inside_flags, build_anchor_generator,
+ build_assigner, build_bbox_coder, build_sampler,
+ calc_region, images_to_levels, multi_apply,
+ multiclass_nms, unmap)
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+class FeatureAdaption(nn.Module):
+ """Feature Adaption Module.
+
+ Feature Adaption Module is implemented based on DCN v1.
+ It uses anchor shape prediction rather than feature map to
+ predict offsets of deform conv layer.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ out_channels (int): Number of channels in the output feature map.
+ kernel_size (int): Deformable conv kernel size.
+ deform_groups (int): Deformable conv group size.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ deform_groups=4):
+ super(FeatureAdaption, self).__init__()
+ offset_channels = kernel_size * kernel_size * 2
+ self.conv_offset = nn.Conv2d(
+ 2, deform_groups * offset_channels, 1, bias=False)
+ self.conv_adaption = DeformConv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ padding=(kernel_size - 1) // 2,
+ deform_groups=deform_groups)
+ self.relu = nn.ReLU(inplace=True)
+
+ def init_weights(self):
+ normal_init(self.conv_offset, std=0.1)
+ normal_init(self.conv_adaption, std=0.01)
+
+ def forward(self, x, shape):
+ offset = self.conv_offset(shape.detach())
+ x = self.relu(self.conv_adaption(x, offset))
+ return x
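
A minimal usage sketch of `FeatureAdaption` as defined above: the 2-channel shape prediction (one (dw, dh) pair per location) is mapped to deformable-conv offsets, so the adapted feature samples the region the predicted anchor covers. This assumes mmcv with the deform-conv op available (older mmcv builds require a GPU for `DeformConv2d`); shapes are illustrative:

```python
import torch

feat_channels, deform_groups = 256, 4
adaption = FeatureAdaption(feat_channels, feat_channels,
                           kernel_size=3, deform_groups=deform_groups)
adaption.init_weights()

x = torch.randn(1, feat_channels, 32, 32)         # one FPN level
shape_pred = torch.randn(1, 2, 32, 32)            # predicted (dw, dh) per location

# shape_pred is detached inside forward(), so the offsets do not back-propagate
# gradients into the anchor-shape branch
adapted = adaption(x, shape_pred)                 # -> (1, 256, 32, 32)
```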
+
+
+@HEADS.register_module()
+class GuidedAnchorHead(AnchorHead):
+ """Guided-Anchor-based head (GA-RPN, GA-RetinaNet, etc.).
+
+ This GuidedAnchorHead will predict high-quality feature guided
+ anchors and locations where anchors will be kept in inference.
+ There are mainly 3 categories of bounding-boxes.
+
+ - Sampled 9 pairs for target assignment. (approxes)
+ - The square boxes where the predicted anchors are based on. (squares)
+ - Guided anchors.
+
+ Please refer to https://arxiv.org/abs/1901.03278 for more details.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of channels in the input feature map.
+ feat_channels (int): Number of hidden channels.
+ approx_anchor_generator (dict): Config dict for approx generator
+ square_anchor_generator (dict): Config dict for square generator
+ anchor_coder (dict): Config dict for anchor coder
+ bbox_coder (dict): Config dict for bbox coder
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ deform_groups: (int): Group number of DCN in
+ FeatureAdaption module.
+ loc_filter_thr (float): Threshold to filter out unconcerned regions.
+ loss_loc (dict): Config of location loss.
+ loss_shape (dict): Config of anchor shape loss.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of bbox regression loss.
+ """
+
+ def __init__(
+ self,
+ num_classes,
+ in_channels,
+ feat_channels=256,
+ approx_anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=8,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[4, 8, 16, 32, 64]),
+ square_anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ scales=[8],
+ strides=[4, 8, 16, 32, 64]),
+ anchor_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]
+ ),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0]
+ ),
+ reg_decoded_bbox=False,
+ deform_groups=4,
+ loc_filter_thr=0.01,
+ train_cfg=None,
+ test_cfg=None,
+ loss_loc=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_shape=dict(type='BoundedIoULoss', beta=0.2, loss_weight=1.0),
+ loss_cls=dict(
+ type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+ loss_bbox=dict(type='SmoothL1Loss', beta=1.0,
+ loss_weight=1.0)): # yapf: disable
+ super(AnchorHead, self).__init__()
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.feat_channels = feat_channels
+ self.deform_groups = deform_groups
+ self.loc_filter_thr = loc_filter_thr
+
+ # build approx_anchor_generator and square_anchor_generator
+ assert (approx_anchor_generator['octave_base_scale'] ==
+ square_anchor_generator['scales'][0])
+ assert (approx_anchor_generator['strides'] ==
+ square_anchor_generator['strides'])
+ self.approx_anchor_generator = build_anchor_generator(
+ approx_anchor_generator)
+ self.square_anchor_generator = build_anchor_generator(
+ square_anchor_generator)
+ self.approxs_per_octave = self.approx_anchor_generator \
+ .num_base_anchors[0]
+
+ self.reg_decoded_bbox = reg_decoded_bbox
+
+ # one anchor per location
+ self.num_anchors = 1
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ self.loc_focal_loss = loss_loc['type'] in ['FocalLoss']
+ self.sampling = loss_cls['type'] not in ['FocalLoss']
+ self.ga_sampling = train_cfg is not None and hasattr(
+ train_cfg, 'ga_sampler')
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes
+ else:
+ self.cls_out_channels = self.num_classes + 1
+
+ # build bbox_coder
+ self.anchor_coder = build_bbox_coder(anchor_coder)
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+
+ # build losses
+ self.loss_loc = build_loss(loss_loc)
+ self.loss_shape = build_loss(loss_shape)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.ga_assigner = build_assigner(self.train_cfg.ga_assigner)
+ if self.ga_sampling:
+ ga_sampler_cfg = self.train_cfg.ga_sampler
+ else:
+ ga_sampler_cfg = dict(type='PseudoSampler')
+ self.ga_sampler = build_sampler(ga_sampler_cfg, context=self)
+
+ self.fp16_enabled = False
+
+ self._init_layers()
+
+ def _init_layers(self):
+ self.relu = nn.ReLU(inplace=True)
+ self.conv_loc = nn.Conv2d(self.in_channels, 1, 1)
+ self.conv_shape = nn.Conv2d(self.in_channels, self.num_anchors * 2, 1)
+ self.feature_adaption = FeatureAdaption(
+ self.in_channels,
+ self.feat_channels,
+ kernel_size=3,
+ deform_groups=self.deform_groups)
+ self.conv_cls = MaskedConv2d(self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 1)
+ self.conv_reg = MaskedConv2d(self.feat_channels, self.num_anchors * 4,
+ 1)
+
+ def init_weights(self):
+ normal_init(self.conv_cls, std=0.01)
+ normal_init(self.conv_reg, std=0.01)
+
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.conv_loc, std=0.01, bias=bias_cls)
+ normal_init(self.conv_shape, std=0.01)
+
+ self.feature_adaption.init_weights()
+
+ def forward_single(self, x):
+ loc_pred = self.conv_loc(x)
+ shape_pred = self.conv_shape(x)
+ x = self.feature_adaption(x, shape_pred)
+ # masked conv is only used during inference for speed-up
+ if not self.training:
+ mask = loc_pred.sigmoid()[0] >= self.loc_filter_thr
+ else:
+ mask = None
+ cls_score = self.conv_cls(x, mask)
+ bbox_pred = self.conv_reg(x, mask)
+ return cls_score, bbox_pred, shape_pred, loc_pred
+
+ def forward(self, feats):
+ return multi_apply(self.forward_single, feats)
+
+ def get_sampled_approxs(self, featmap_sizes, img_metas, device='cuda'):
+ """Get sampled approxs and inside flags according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): device for returned tensors
+
+ Returns:
+ tuple: approxes of each image, inside flags of each image
+ """
+ num_imgs = len(img_metas)
+
+ # since feature map sizes of all images are the same, we only compute
+ # approxes for one time
+ multi_level_approxs = self.approx_anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+ approxs_list = [multi_level_approxs for _ in range(num_imgs)]
+
+ # for each image, we compute inside flags of multi level approxes
+ inside_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = []
+ multi_level_approxs = approxs_list[img_id]
+
+ # obtain valid flags for each approx first
+ multi_level_approx_flags = self.approx_anchor_generator \
+ .valid_flags(featmap_sizes,
+ img_meta['pad_shape'],
+ device=device)
+
+ for i, flags in enumerate(multi_level_approx_flags):
+ approxs = multi_level_approxs[i]
+ inside_flags_list = []
+ for i in range(self.approxs_per_octave):
+ split_valid_flags = flags[i::self.approxs_per_octave]
+ split_approxs = approxs[i::self.approxs_per_octave, :]
+ inside_flags = anchor_inside_flags(
+ split_approxs, split_valid_flags,
+ img_meta['img_shape'][:2],
+ self.train_cfg.allowed_border)
+ inside_flags_list.append(inside_flags)
+ # inside_flag for a position is true if any anchor in this
+ # position is true
+ inside_flags = (
+ torch.stack(inside_flags_list, 0).sum(dim=0) > 0)
+ multi_level_flags.append(inside_flags)
+ inside_flag_list.append(multi_level_flags)
+ return approxs_list, inside_flag_list
+
+ def get_anchors(self,
+ featmap_sizes,
+ shape_preds,
+ loc_preds,
+ img_metas,
+ use_loc_filter=False,
+ device='cuda'):
+ """Get squares according to feature map sizes and guided anchors.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ shape_preds (list[tensor]): Multi-level shape predictions.
+ loc_preds (list[tensor]): Multi-level location predictions.
+ img_metas (list[dict]): Image meta info.
+ use_loc_filter (bool): Use loc filter or not.
+ device (torch.device | str): device for returned tensors
+
+ Returns:
+ tuple: square approxs of each image, guided anchors of each image,
+ loc masks of each image
+ """
+ num_imgs = len(img_metas)
+ num_levels = len(featmap_sizes)
+
+ # since feature map sizes of all images are the same, we only compute
+ # squares for one time
+ multi_level_squares = self.square_anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+ squares_list = [multi_level_squares for _ in range(num_imgs)]
+
+ # for each image, we compute multi level guided anchors
+ guided_anchors_list = []
+ loc_mask_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_guided_anchors = []
+ multi_level_loc_mask = []
+ for i in range(num_levels):
+ squares = squares_list[img_id][i]
+ shape_pred = shape_preds[i][img_id]
+ loc_pred = loc_preds[i][img_id]
+ guided_anchors, loc_mask = self._get_guided_anchors_single(
+ squares,
+ shape_pred,
+ loc_pred,
+ use_loc_filter=use_loc_filter)
+ multi_level_guided_anchors.append(guided_anchors)
+ multi_level_loc_mask.append(loc_mask)
+ guided_anchors_list.append(multi_level_guided_anchors)
+ loc_mask_list.append(multi_level_loc_mask)
+ return squares_list, guided_anchors_list, loc_mask_list
+
+ def _get_guided_anchors_single(self,
+ squares,
+ shape_pred,
+ loc_pred,
+ use_loc_filter=False):
+ """Get guided anchors and loc masks for a single level.
+
+ Args:
+            squares (tensor): Squares of a single level.
+            shape_pred (tensor): Shape predictions of a single level.
+            loc_pred (tensor): Loc predictions of a single level.
+            use_loc_filter (bool): Use loc filter or not.
+
+ Returns:
+ tuple: guided anchors, location masks
+ """
+ # calculate location filtering mask
+ loc_pred = loc_pred.sigmoid().detach()
+ if use_loc_filter:
+ loc_mask = loc_pred >= self.loc_filter_thr
+ else:
+ loc_mask = loc_pred >= 0.0
+ mask = loc_mask.permute(1, 2, 0).expand(-1, -1, self.num_anchors)
+ mask = mask.contiguous().view(-1)
+ # calculate guided anchors
+ squares = squares[mask]
+ anchor_deltas = shape_pred.permute(1, 2, 0).contiguous().view(
+ -1, 2).detach()[mask]
+ bbox_deltas = anchor_deltas.new_full(squares.size(), 0)
+ bbox_deltas[:, 2:] = anchor_deltas
+ guided_anchors = self.anchor_coder.decode(
+ squares, bbox_deltas, wh_ratio_clip=1e-6)
+ return guided_anchors, mask
+
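+    # Worked example for the decode above (a sketch assuming the default
+    # DeltaXYWHBBoxCoder with unit stds): since the deltas are (0, 0, dw, dh),
+    # a square keeps its center and only its size is rescaled, e.g. a 64x64
+    # square with shape_pred (dw, dh) = (0.7, -0.7) becomes roughly a 129x32
+    # guided anchor (w' = 64 * exp(0.7), h' = 64 * exp(-0.7)).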
+ def ga_loc_targets(self, gt_bboxes_list, featmap_sizes):
+ """Compute location targets for guided anchoring.
+
+ Each feature map is divided into positive, negative and ignore regions.
+ - positive regions: target 1, weight 1
+ - ignore regions: target 0, weight 0
+ - negative regions: target 0, weight 0.1
+
+ Args:
+ gt_bboxes_list (list[Tensor]): Gt bboxes of each image.
+ featmap_sizes (list[tuple]): Multi level sizes of each feature
+ maps.
+
+ Returns:
+            tuple: loc targets, loc weights and loc average factor of all
+                feature levels.
+ """
+ anchor_scale = self.approx_anchor_generator.octave_base_scale
+ anchor_strides = self.approx_anchor_generator.strides
+ # Currently only supports same stride in x and y direction.
+ for stride in anchor_strides:
+ assert (stride[0] == stride[1])
+ anchor_strides = [stride[0] for stride in anchor_strides]
+
+ center_ratio = self.train_cfg.center_ratio
+ ignore_ratio = self.train_cfg.ignore_ratio
+ img_per_gpu = len(gt_bboxes_list)
+ num_lvls = len(featmap_sizes)
+ r1 = (1 - center_ratio) / 2
+ r2 = (1 - ignore_ratio) / 2
+ all_loc_targets = []
+ all_loc_weights = []
+ all_ignore_map = []
+ for lvl_id in range(num_lvls):
+ h, w = featmap_sizes[lvl_id]
+ loc_targets = torch.zeros(
+ img_per_gpu,
+ 1,
+ h,
+ w,
+ device=gt_bboxes_list[0].device,
+ dtype=torch.float32)
+ loc_weights = torch.full_like(loc_targets, -1)
+ ignore_map = torch.zeros_like(loc_targets)
+ all_loc_targets.append(loc_targets)
+ all_loc_weights.append(loc_weights)
+ all_ignore_map.append(ignore_map)
+ for img_id in range(img_per_gpu):
+ gt_bboxes = gt_bboxes_list[img_id]
+ scale = torch.sqrt((gt_bboxes[:, 2] - gt_bboxes[:, 0]) *
+ (gt_bboxes[:, 3] - gt_bboxes[:, 1]))
+ min_anchor_size = scale.new_full(
+ (1, ), float(anchor_scale * anchor_strides[0]))
+ # assign gt bboxes to different feature levels w.r.t. their scales
+ target_lvls = torch.floor(
+ torch.log2(scale) - torch.log2(min_anchor_size) + 0.5)
+ target_lvls = target_lvls.clamp(min=0, max=num_lvls - 1).long()
+ for gt_id in range(gt_bboxes.size(0)):
+ lvl = target_lvls[gt_id].item()
+ # rescaled to corresponding feature map
+ gt_ = gt_bboxes[gt_id, :4] / anchor_strides[lvl]
+ # calculate ignore regions
+ ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
+ gt_, r2, featmap_sizes[lvl])
+ # calculate positive (center) regions
+ ctr_x1, ctr_y1, ctr_x2, ctr_y2 = calc_region(
+ gt_, r1, featmap_sizes[lvl])
+ all_loc_targets[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
+ ctr_x1:ctr_x2 + 1] = 1
+ all_loc_weights[lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
+ ignore_x1:ignore_x2 + 1] = 0
+ all_loc_weights[lvl][img_id, 0, ctr_y1:ctr_y2 + 1,
+ ctr_x1:ctr_x2 + 1] = 1
+ # calculate ignore map on nearby low level feature
+ if lvl > 0:
+ d_lvl = lvl - 1
+ # rescaled to corresponding feature map
+ gt_ = gt_bboxes[gt_id, :4] / anchor_strides[d_lvl]
+ ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
+ gt_, r2, featmap_sizes[d_lvl])
+ all_ignore_map[d_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
+ ignore_x1:ignore_x2 + 1] = 1
+ # calculate ignore map on nearby high level feature
+ if lvl < num_lvls - 1:
+ u_lvl = lvl + 1
+ # rescaled to corresponding feature map
+ gt_ = gt_bboxes[gt_id, :4] / anchor_strides[u_lvl]
+ ignore_x1, ignore_y1, ignore_x2, ignore_y2 = calc_region(
+ gt_, r2, featmap_sizes[u_lvl])
+ all_ignore_map[u_lvl][img_id, 0, ignore_y1:ignore_y2 + 1,
+ ignore_x1:ignore_x2 + 1] = 1
+ for lvl_id in range(num_lvls):
+ # ignore negative regions w.r.t. ignore map
+ all_loc_weights[lvl_id][(all_loc_weights[lvl_id] < 0)
+ & (all_ignore_map[lvl_id] > 0)] = 0
+ # set negative regions with weight 0.1
+ all_loc_weights[lvl_id][all_loc_weights[lvl_id] < 0] = 0.1
+ # loc average factor to balance loss
+ loc_avg_factor = sum(
+ [t.size(0) * t.size(-1) * t.size(-2)
+ for t in all_loc_targets]) / 200
+ return all_loc_targets, all_loc_weights, loc_avg_factor
+
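+    # Numerical sketch of the region scheme above (assuming the typical GA
+    # config center_ratio=0.2 and ignore_ratio=0.5, i.e. r1=0.4 and r2=0.25):
+    # on the level a gt box is assigned to, the central 0.2-scaled region is
+    # marked positive (target 1, weight 1), the surrounding 0.5-scaled region
+    # is ignored (weight 0), and the remaining locations keep the small
+    # negative weight 0.1.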
+ def _ga_shape_target_single(self,
+ flat_approxs,
+ inside_flags,
+ flat_squares,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ img_meta,
+ unmap_outputs=True):
+ """Compute guided anchoring targets.
+
+ This function returns sampled anchors and gt bboxes directly
+ rather than calculates regression targets.
+
+ Args:
+            flat_approxs (Tensor): flat approxs of a single image,
+                shape (approxs_per_octave * n, 4)
+            inside_flags (Tensor): inside flags of a single image,
+                shape (n, ).
+            flat_squares (Tensor): flat squares of a single image,
+                shape (n, 4)
+            gt_bboxes (Tensor): Ground truth bboxes of a single image.
+            gt_bboxes_ignore (Tensor): Ground truth bboxes to be ignored.
+            img_meta (dict): Meta info of a single image.
+            unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+            tuple: bbox anchors, bbox gts, bbox weights, and the indices of
+                positive and negative samples.
+ """
+ if not inside_flags.any():
+ return (None, ) * 5
+ # assign gt and sample anchors
+ expand_inside_flags = inside_flags[:, None].expand(
+ -1, self.approxs_per_octave).reshape(-1)
+ approxs = flat_approxs[expand_inside_flags, :]
+ squares = flat_squares[inside_flags, :]
+
+ assign_result = self.ga_assigner.assign(approxs, squares,
+ self.approxs_per_octave,
+ gt_bboxes, gt_bboxes_ignore)
+ sampling_result = self.ga_sampler.sample(assign_result, squares,
+ gt_bboxes)
+
+ bbox_anchors = torch.zeros_like(squares)
+ bbox_gts = torch.zeros_like(squares)
+ bbox_weights = torch.zeros_like(squares)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ bbox_anchors[pos_inds, :] = sampling_result.pos_bboxes
+ bbox_gts[pos_inds, :] = sampling_result.pos_gt_bboxes
+ bbox_weights[pos_inds, :] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_squares.size(0)
+ bbox_anchors = unmap(bbox_anchors, num_total_anchors, inside_flags)
+ bbox_gts = unmap(bbox_gts, num_total_anchors, inside_flags)
+ bbox_weights = unmap(bbox_weights, num_total_anchors, inside_flags)
+
+ return (bbox_anchors, bbox_gts, bbox_weights, pos_inds, neg_inds)
+
+ def ga_shape_targets(self,
+ approx_list,
+ inside_flag_list,
+ square_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ unmap_outputs=True):
+ """Compute guided anchoring targets.
+
+ Args:
+ approx_list (list[list]): Multi level approxs of each image.
+ inside_flag_list (list[list]): Multi level inside flags of each
+ image.
+ square_list (list[list]): Multi level squares of each image.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+            tuple: bbox anchors, gts and weights of all levels, plus the
+                numbers of positive and negative samples.
+ """
+ num_imgs = len(img_metas)
+ assert len(approx_list) == len(inside_flag_list) == len(
+ square_list) == num_imgs
+ # anchor number of multi levels
+ num_level_squares = [squares.size(0) for squares in square_list[0]]
+ # concat all level anchors and flags to a single tensor
+ inside_flag_flat_list = []
+ approx_flat_list = []
+ square_flat_list = []
+ for i in range(num_imgs):
+ assert len(square_list[i]) == len(inside_flag_list[i])
+ inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
+ approx_flat_list.append(torch.cat(approx_list[i]))
+ square_flat_list.append(torch.cat(square_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ (all_bbox_anchors, all_bbox_gts, all_bbox_weights, pos_inds_list,
+ neg_inds_list) = multi_apply(
+ self._ga_shape_target_single,
+ approx_flat_list,
+ inside_flag_flat_list,
+ square_flat_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ img_metas,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([bbox_anchors is None for bbox_anchors in all_bbox_anchors]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ bbox_anchors_list = images_to_levels(all_bbox_anchors,
+ num_level_squares)
+ bbox_gts_list = images_to_levels(all_bbox_gts, num_level_squares)
+ bbox_weights_list = images_to_levels(all_bbox_weights,
+ num_level_squares)
+ return (bbox_anchors_list, bbox_gts_list, bbox_weights_list,
+ num_total_pos, num_total_neg)
+
+ def loss_shape_single(self, shape_pred, bbox_anchors, bbox_gts,
+ anchor_weights, anchor_total_num):
+ shape_pred = shape_pred.permute(0, 2, 3, 1).contiguous().view(-1, 2)
+ bbox_anchors = bbox_anchors.contiguous().view(-1, 4)
+ bbox_gts = bbox_gts.contiguous().view(-1, 4)
+ anchor_weights = anchor_weights.contiguous().view(-1, 4)
+ bbox_deltas = bbox_anchors.new_full(bbox_anchors.size(), 0)
+ bbox_deltas[:, 2:] += shape_pred
+ # filter out negative samples to speed-up weighted_bounded_iou_loss
+ inds = torch.nonzero(
+ anchor_weights[:, 0] > 0, as_tuple=False).squeeze(1)
+ bbox_deltas_ = bbox_deltas[inds]
+ bbox_anchors_ = bbox_anchors[inds]
+ bbox_gts_ = bbox_gts[inds]
+ anchor_weights_ = anchor_weights[inds]
+ pred_anchors_ = self.anchor_coder.decode(
+ bbox_anchors_, bbox_deltas_, wh_ratio_clip=1e-6)
+ loss_shape = self.loss_shape(
+ pred_anchors_,
+ bbox_gts_,
+ anchor_weights_,
+ avg_factor=anchor_total_num)
+ return loss_shape
+
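+    # Note on the shape loss above (an informal reading of the code, not an
+    # addition to it): only (dw, dh) deltas are predicted, so the decoded
+    # pred_anchors_ keep the centers of the square anchors; with the default
+    # BoundedIoULoss the width/height of positive locations (anchor_weights > 0)
+    # is then supervised against the matched gt boxes.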
+ def loss_loc_single(self, loc_pred, loc_target, loc_weight,
+ loc_avg_factor):
+ loss_loc = self.loss_loc(
+ loc_pred.reshape(-1, 1),
+ loc_target.reshape(-1).long(),
+ loc_weight.reshape(-1),
+ avg_factor=loc_avg_factor)
+ return loss_loc
+
+ @force_fp32(
+ apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.approx_anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ # get loc targets
+ loc_targets, loc_weights, loc_avg_factor = self.ga_loc_targets(
+ gt_bboxes, featmap_sizes)
+
+ # get sampled approxes
+ approxs_list, inside_flag_list = self.get_sampled_approxs(
+ featmap_sizes, img_metas, device=device)
+ # get squares and guided anchors
+ squares_list, guided_anchors_list, _ = self.get_anchors(
+ featmap_sizes, shape_preds, loc_preds, img_metas, device=device)
+
+ # get shape targets
+ shape_targets = self.ga_shape_targets(approxs_list, inside_flag_list,
+ squares_list, gt_bboxes,
+ img_metas)
+ if shape_targets is None:
+ return None
+ (bbox_anchors_list, bbox_gts_list, anchor_weights_list, anchor_fg_num,
+ anchor_bg_num) = shape_targets
+ anchor_total_num = (
+ anchor_fg_num if not self.ga_sampling else anchor_fg_num +
+ anchor_bg_num)
+
+ # get anchor targets
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ guided_anchors_list,
+ inside_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [
+ anchors.size(0) for anchors in guided_anchors_list[0]
+ ]
+ # concat all level anchors to a single tensor
+ concat_anchor_list = []
+ for i in range(len(guided_anchors_list)):
+ concat_anchor_list.append(torch.cat(guided_anchors_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ # get classification and bbox regression losses
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+
+ # get anchor location loss
+ losses_loc = []
+ for i in range(len(loc_preds)):
+ loss_loc = self.loss_loc_single(
+ loc_preds[i],
+ loc_targets[i],
+ loc_weights[i],
+ loc_avg_factor=loc_avg_factor)
+ losses_loc.append(loss_loc)
+
+ # get anchor shape loss
+ losses_shape = []
+ for i in range(len(shape_preds)):
+ loss_shape = self.loss_shape_single(
+ shape_preds[i],
+ bbox_anchors_list[i],
+ bbox_gts_list[i],
+ anchor_weights_list[i],
+ anchor_total_num=anchor_total_num)
+ losses_shape.append(loss_shape)
+
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_shape=losses_shape,
+ loss_loc=losses_loc)
+
+ @force_fp32(
+ apply_to=('cls_scores', 'bbox_preds', 'shape_preds', 'loc_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ shape_preds,
+ loc_preds,
+ img_metas,
+ cfg=None,
+ rescale=False):
+ assert len(cls_scores) == len(bbox_preds) == len(shape_preds) == len(
+ loc_preds)
+ num_levels = len(cls_scores)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ device = cls_scores[0].device
+ # get guided anchors
+ _, guided_anchors, loc_masks = self.get_anchors(
+ featmap_sizes,
+ shape_preds,
+ loc_preds,
+ img_metas,
+ use_loc_filter=not self.training,
+ device=device)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ guided_anchor_list = [
+ guided_anchors[img_id][i].detach() for i in range(num_levels)
+ ]
+ loc_mask_list = [
+ loc_masks[img_id][i].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ guided_anchor_list,
+ loc_mask_list, img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ mlvl_masks,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for cls_score, bbox_pred, anchors, mask in zip(cls_scores, bbox_preds,
+ mlvl_anchors,
+ mlvl_masks):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ # if no location is kept, end.
+ if mask.sum() == 0:
+ continue
+ # reshape scores and bbox_pred
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ # filter scores, bbox_pred w.r.t. mask.
+ # anchors are filtered in get_anchors() beforehand.
+ scores = scores[mask, :]
+ bbox_pred = bbox_pred[mask, :]
+ if scores.dim() == 0:
+ anchors = anchors.unsqueeze(0)
+ scores = scores.unsqueeze(0)
+ bbox_pred = bbox_pred.unsqueeze(0)
+ # filter anchors, bbox_pred, scores w.r.t. scores
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[0] > nms_pre:
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ if self.use_sigmoid_cls:
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ # multi class NMS
+ det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ return det_bboxes, det_labels
diff --git a/mmdet/models/dense_heads/ld_head.py b/mmdet/models/dense_heads/ld_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..501e1f7befa086f0b2f818531807411fc383d7bd
--- /dev/null
+++ b/mmdet/models/dense_heads/ld_head.py
@@ -0,0 +1,261 @@
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import (bbox2distance, bbox_overlaps, distance2bbox,
+ multi_apply, reduce_mean)
+from ..builder import HEADS, build_loss
+from .gfl_head import GFLHead
+
+
+@HEADS.register_module()
+class LDHead(GFLHead):
+ """Localization distillation Head. (Short description)
+
+ It utilizes the learned bbox distributions to transfer the localization
+ dark knowledge from teacher to student. Original paper: `Localization
+ Distillation for Object Detection. `_
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ loss_ld (dict): Config of Localization Distillation Loss (LD),
+ T is the temperature for distillation.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ loss_ld=dict(
+ type='LocalizationDistillationLoss',
+ loss_weight=0.25,
+ T=10),
+ **kwargs):
+
+ super(LDHead, self).__init__(num_classes, in_channels, **kwargs)
+ self.loss_ld = build_loss(loss_ld)
+
+ def loss_single(self, anchors, cls_score, bbox_pred, labels, label_weights,
+ bbox_targets, stride, soft_targets, num_total_samples):
+ """Compute loss of a single scale level.
+
+ Args:
+ anchors (Tensor): Box reference for each scale level with shape
+ (N, num_total_anchors, 4).
+ cls_score (Tensor): Cls and quality joint scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_pred (Tensor): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ labels (Tensor): Labels of each anchors with shape
+ (N, num_total_anchors).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (N, num_total_anchors)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+                shape (N, num_total_anchors, 4).
+            stride (tuple): Stride in this scale level.
+            soft_targets (Tensor): Box distribution logits of the teacher model
+                for this scale level with shape (N, 4*(n+1), H, W).
+            num_total_samples (int): Number of positive samples that is
+                reduced over all GPUs.
+
+ Returns:
+            tuple: Loss components and weight targets.
+ """
+ assert stride[0] == stride[1], 'h stride is not equal to w stride!'
+ anchors = anchors.reshape(-1, 4)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(-1, 4 * (self.reg_max + 1))
+ soft_targets = soft_targets.permute(0, 2, 3,
+ 1).reshape(-1,
+ 4 * (self.reg_max + 1))
+
+ bbox_targets = bbox_targets.reshape(-1, 4)
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = ((labels >= 0)
+ & (labels < bg_class_ind)).nonzero().squeeze(1)
+ score = label_weights.new_zeros(labels.shape)
+
+ if len(pos_inds) > 0:
+ pos_bbox_targets = bbox_targets[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_anchors = anchors[pos_inds]
+ pos_anchor_centers = self.anchor_center(pos_anchors) / stride[0]
+
+ weight_targets = cls_score.detach().sigmoid()
+ weight_targets = weight_targets.max(dim=1)[0][pos_inds]
+ pos_bbox_pred_corners = self.integral(pos_bbox_pred)
+ pos_decode_bbox_pred = distance2bbox(pos_anchor_centers,
+ pos_bbox_pred_corners)
+ pos_decode_bbox_targets = pos_bbox_targets / stride[0]
+ score[pos_inds] = bbox_overlaps(
+ pos_decode_bbox_pred.detach(),
+ pos_decode_bbox_targets,
+ is_aligned=True)
+ pred_corners = pos_bbox_pred.reshape(-1, self.reg_max + 1)
+ pos_soft_targets = soft_targets[pos_inds]
+ soft_corners = pos_soft_targets.reshape(-1, self.reg_max + 1)
+
+ target_corners = bbox2distance(pos_anchor_centers,
+ pos_decode_bbox_targets,
+ self.reg_max).reshape(-1)
+
+ # regression loss
+ loss_bbox = self.loss_bbox(
+ pos_decode_bbox_pred,
+ pos_decode_bbox_targets,
+ weight=weight_targets,
+ avg_factor=1.0)
+
+ # dfl loss
+ loss_dfl = self.loss_dfl(
+ pred_corners,
+ target_corners,
+ weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+ avg_factor=4.0)
+
+ # ld loss
+ loss_ld = self.loss_ld(
+ pred_corners,
+ soft_corners,
+ weight=weight_targets[:, None].expand(-1, 4).reshape(-1),
+ avg_factor=4.0)
+
+ else:
+ loss_ld = bbox_pred.sum() * 0
+ loss_bbox = bbox_pred.sum() * 0
+ loss_dfl = bbox_pred.sum() * 0
+ weight_targets = bbox_pred.new_tensor(0)
+
+ # cls (qfl) loss
+ loss_cls = self.loss_cls(
+ cls_score, (labels, score),
+ weight=label_weights,
+ avg_factor=num_total_samples)
+
+ return loss_cls, loss_bbox, loss_dfl, loss_ld, weight_targets.sum()
+
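+    # Sketch of the LD term above (a hedged reading, not part of the original
+    # code): pred_corners and soft_corners are the student's and teacher's box
+    # distribution logits reshaped to (num_pos * 4, reg_max + 1); with the
+    # default LocalizationDistillationLoss the term is a temperature-softened
+    # (T=10) divergence between the two distributions, weighted per box side by
+    # the detached max classification score of each positive anchor
+    # (weight_targets).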
+ def forward_train(self,
+ x,
+ out_teacher,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """
+ Args:
+            x (list[Tensor]): Features from FPN.
+            out_teacher (tuple): Outputs of the teacher head; its bbox
+                distribution predictions (``out_teacher[1]``) are used as soft
+                targets for localization distillation.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used
+
+ Returns:
+ tuple[dict, list]: The loss components and proposals of each image.
+
+ - losses (dict[str, Tensor]): A dictionary of loss components.
+ - proposal_list (list[Tensor]): Proposals of each image.
+ """
+ outs = self(x)
+ soft_target = out_teacher[1]
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, soft_target, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, soft_target, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ if proposal_cfg is None:
+ return losses
+ else:
+ proposal_list = self.get_bboxes(*outs, img_metas, cfg=proposal_cfg)
+ return losses, proposal_list
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ soft_target,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Cls and quality scores for each scale
+ level has shape (N, num_classes, H, W).
+ bbox_preds (list[Tensor]): Box distribution logits for each scale
+ level with shape (N, 4*(n+1), H, W), n is max value of integral
+ set.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): class indices corresponding to each box
+            soft_target (list[Tensor]): Box distribution logits predicted by
+                the teacher model for each scale level.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_total_samples = reduce_mean(
+ torch.tensor(num_total_pos, dtype=torch.float,
+ device=device)).item()
+ num_total_samples = max(num_total_samples, 1.0)
+
+ losses_cls, losses_bbox, losses_dfl, losses_ld, \
+ avg_factor = multi_apply(
+ self.loss_single,
+ anchor_list,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ self.anchor_generator.strides,
+ soft_target,
+ num_total_samples=num_total_samples)
+
+ avg_factor = sum(avg_factor) + 1e-6
+ avg_factor = reduce_mean(avg_factor).item()
+ losses_bbox = [x / avg_factor for x in losses_bbox]
+ losses_dfl = [x / avg_factor for x in losses_dfl]
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox=losses_bbox,
+ loss_dfl=losses_dfl,
+ loss_ld=losses_ld)
diff --git a/mmdet/models/dense_heads/nasfcos_head.py b/mmdet/models/dense_heads/nasfcos_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..994ce0455e1982110f237b3958a81394c319bb47
--- /dev/null
+++ b/mmdet/models/dense_heads/nasfcos_head.py
@@ -0,0 +1,75 @@
+import copy
+
+import torch.nn as nn
+from mmcv.cnn import (ConvModule, Scale, bias_init_with_prob,
+ caffe2_xavier_init, normal_init)
+
+from mmdet.models.dense_heads.fcos_head import FCOSHead
+from ..builder import HEADS
+
+
+@HEADS.register_module()
+class NASFCOSHead(FCOSHead):
+ """Anchor-free head used in `NASFCOS `_.
+
+ It is quite similar with FCOS head, except for the searched structure of
+ classification branch and bbox regression branch, where a structure of
+ "dconv3x3, conv3x3, dconv3x3, conv1x1" is utilized instead.
+ """
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ dconv3x3_config = dict(
+ type='DCNv2',
+ kernel_size=3,
+ use_bias=True,
+ deform_groups=2,
+ padding=1)
+ conv3x3_config = dict(type='Conv', kernel_size=3, padding=1)
+ conv1x1_config = dict(type='Conv', kernel_size=1)
+
+ self.arch_config = [
+ dconv3x3_config, conv3x3_config, dconv3x3_config, conv1x1_config
+ ]
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i, op_ in enumerate(self.arch_config):
+ op = copy.deepcopy(op_)
+ chn = self.in_channels if i == 0 else self.feat_channels
+ assert isinstance(op, dict)
+ use_bias = op.pop('use_bias', False)
+ padding = op.pop('padding', 0)
+ kernel_size = op.pop('kernel_size')
+ module = ConvModule(
+ chn,
+ self.feat_channels,
+ kernel_size,
+ stride=1,
+ padding=padding,
+ norm_cfg=self.norm_cfg,
+ bias=use_bias,
+ conv_cfg=op)
+
+ self.cls_convs.append(copy.deepcopy(module))
+ self.reg_convs.append(copy.deepcopy(module))
+
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.conv_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+ self.conv_centerness = nn.Conv2d(self.feat_channels, 1, 3, padding=1)
+
+ self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
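+    # Shape note for the searched branches above (informal): each branch maps
+    # in_channels -> feat_channels with the first op and feat_channels ->
+    # feat_channels afterwards, so cls_convs and reg_convs are
+    # "dconv3x3 -> conv3x3 -> dconv3x3 -> conv1x1" stacks that share the same
+    # searched configuration but not the same weights (each is a deepcopy).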
+ def init_weights(self):
+ """Initialize weights of the head."""
+ # retinanet_bias_init
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.conv_reg, std=0.01)
+ normal_init(self.conv_centerness, std=0.01)
+ normal_init(self.conv_cls, std=0.01, bias=bias_cls)
+
+ for branch in [self.cls_convs, self.reg_convs]:
+ for module in branch.modules():
+ if isinstance(module, ConvModule) \
+ and isinstance(module.conv, nn.Conv2d):
+ caffe2_xavier_init(module.conv)
diff --git a/mmdet/models/dense_heads/paa_head.py b/mmdet/models/dense_heads/paa_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e067b0121cf8b8230c0c9c6b8cfd41f56be4e298
--- /dev/null
+++ b/mmdet/models/dense_heads/paa_head.py
@@ -0,0 +1,671 @@
+import numpy as np
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import multi_apply, multiclass_nms
+from mmdet.core.bbox.iou_calculators import bbox_overlaps
+from mmdet.models import HEADS
+from mmdet.models.dense_heads import ATSSHead
+
+EPS = 1e-12
+try:
+ import sklearn.mixture as skm
+except ImportError:
+ skm = None
+
+
+def levels_to_images(mlvl_tensor):
+ """Concat multi-level feature maps by image.
+
+ [feature_level0, feature_level1...] -> [feature_image0, feature_image1...]
+ Convert the shape of each element in mlvl_tensor from (N, C, H, W) to
+    (N, H*W, C), then split each element into N chunks of shape (H*W, C), and
+    concat the chunks belonging to the same image across all levels along the
+    first dimension.
+
+ Args:
+ mlvl_tensor (list[torch.Tensor]): list of Tensor which collect from
+ corresponding level. Each element is of shape (N, C, H, W)
+
+ Returns:
+ list[torch.Tensor]: A list that contains N tensors and each tensor is
+ of shape (num_elements, C)
+ """
+ batch_size = mlvl_tensor[0].size(0)
+ batch_list = [[] for _ in range(batch_size)]
+ channels = mlvl_tensor[0].size(1)
+ for t in mlvl_tensor:
+ t = t.permute(0, 2, 3, 1)
+ t = t.view(batch_size, -1, channels).contiguous()
+ for img in range(batch_size):
+ batch_list[img].append(t[img])
+ return [torch.cat(item, 0) for item in batch_list]
+
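+# Shape sketch for levels_to_images (illustrative only): with a batch of N=2
+# images and two levels of shape (2, C, 8, 8) and (2, C, 4, 4), the output is a
+# list of two tensors, each of shape (8*8 + 4*4, C) = (80, C).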
+
+@HEADS.register_module()
+class PAAHead(ATSSHead):
+ """Head of PAAAssignment: Probabilistic Anchor Assignment with IoU
+ Prediction for Object Detection.
+
+    Code is modified from the `official github repo
+    <https://github.com/kkhoot/PAA>`_.
+
+    More details can be found in the paper.
+
+ Args:
+ topk (int): Select topk samples with smallest loss in
+ each level.
+ score_voting (bool): Whether to use score voting in post-process.
+ covariance_type : String describing the type of covariance parameters
+ to be used in :class:`sklearn.mixture.GaussianMixture`.
+ It must be one of:
+
+ - 'full': each component has its own general covariance matrix
+ - 'tied': all components share the same general covariance matrix
+ - 'diag': each component has its own diagonal covariance matrix
+ - 'spherical': each component has its own single variance
+ Default: 'diag'. From 'full' to 'spherical', the gmm fitting
+ process is faster yet the performance could be influenced. For most
+ cases, 'diag' should be a good choice.
+ """
+
+ def __init__(self,
+ *args,
+ topk=9,
+ score_voting=True,
+ covariance_type='diag',
+ **kwargs):
+ # topk used in paa reassign process
+ self.topk = topk
+ self.with_score_voting = score_voting
+ self.covariance_type = covariance_type
+ super(PAAHead, self).__init__(*args, **kwargs)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'iou_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ iou_preds (list[Tensor]): iou_preds for each scale
+ level with shape (N, num_anchors * 1, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor] | None): Specify which bounding
+            boxes can be ignored when computing the loss.
+
+ Returns:
+            dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ )
+ (labels, labels_weight, bboxes_target, bboxes_weight, pos_inds,
+ pos_gt_index) = cls_reg_targets
+ cls_scores = levels_to_images(cls_scores)
+ cls_scores = [
+ item.reshape(-1, self.cls_out_channels) for item in cls_scores
+ ]
+ bbox_preds = levels_to_images(bbox_preds)
+ bbox_preds = [item.reshape(-1, 4) for item in bbox_preds]
+ iou_preds = levels_to_images(iou_preds)
+ iou_preds = [item.reshape(-1, 1) for item in iou_preds]
+ pos_losses_list, = multi_apply(self.get_pos_loss, anchor_list,
+ cls_scores, bbox_preds, labels,
+ labels_weight, bboxes_target,
+ bboxes_weight, pos_inds)
+
+ with torch.no_grad():
+ reassign_labels, reassign_label_weight, \
+ reassign_bbox_weights, num_pos = multi_apply(
+ self.paa_reassign,
+ pos_losses_list,
+ labels,
+ labels_weight,
+ bboxes_weight,
+ pos_inds,
+ pos_gt_index,
+ anchor_list)
+ num_pos = sum(num_pos)
+ # convert all tensor list to a flatten tensor
+ cls_scores = torch.cat(cls_scores, 0).view(-1, cls_scores[0].size(-1))
+ bbox_preds = torch.cat(bbox_preds, 0).view(-1, bbox_preds[0].size(-1))
+ iou_preds = torch.cat(iou_preds, 0).view(-1, iou_preds[0].size(-1))
+ labels = torch.cat(reassign_labels, 0).view(-1)
+ flatten_anchors = torch.cat(
+ [torch.cat(item, 0) for item in anchor_list])
+ labels_weight = torch.cat(reassign_label_weight, 0).view(-1)
+ bboxes_target = torch.cat(bboxes_target,
+ 0).view(-1, bboxes_target[0].size(-1))
+
+ pos_inds_flatten = ((labels >= 0)
+ &
+ (labels < self.num_classes)).nonzero().reshape(-1)
+
+ losses_cls = self.loss_cls(
+ cls_scores,
+ labels,
+ labels_weight,
+ avg_factor=max(num_pos, len(img_metas))) # avoid num_pos=0
+ if num_pos:
+ pos_bbox_pred = self.bbox_coder.decode(
+ flatten_anchors[pos_inds_flatten],
+ bbox_preds[pos_inds_flatten])
+ pos_bbox_target = bboxes_target[pos_inds_flatten]
+ iou_target = bbox_overlaps(
+ pos_bbox_pred.detach(), pos_bbox_target, is_aligned=True)
+ losses_iou = self.loss_centerness(
+ iou_preds[pos_inds_flatten],
+ iou_target.unsqueeze(-1),
+ avg_factor=num_pos)
+ losses_bbox = self.loss_bbox(
+ pos_bbox_pred,
+ pos_bbox_target,
+ iou_target.clamp(min=EPS),
+ avg_factor=iou_target.sum())
+ else:
+ losses_iou = iou_preds.sum() * 0
+ losses_bbox = bbox_preds.sum() * 0
+
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox, loss_iou=losses_iou)
+
+ def get_pos_loss(self, anchors, cls_score, bbox_pred, label, label_weight,
+ bbox_target, bbox_weight, pos_inds):
+ """Calculate loss of all potential positive samples obtained from first
+ match process.
+
+ Args:
+ anchors (list[Tensor]): Anchors of each scale.
+ cls_score (Tensor): Box scores of single image with shape
+ (num_anchors, num_classes)
+ bbox_pred (Tensor): Box energies / deltas of single image
+ with shape (num_anchors, 4)
+ label (Tensor): classification target of each anchor with
+ shape (num_anchors,)
+ label_weight (Tensor): Classification loss weight of each
+ anchor with shape (num_anchors).
+            bbox_target (Tensor): Regression target of each anchor with
+ shape (num_anchors, 4).
+ bbox_weight (Tensor): Bbox weight of each anchor with shape
+ (num_anchors, 4).
+ pos_inds (Tensor): Index of all positive samples got from
+ first assign process.
+
+ Returns:
+ Tensor: Losses of all positive samples in single image.
+ """
+ if not len(pos_inds):
+ return cls_score.new([]),
+ anchors_all_level = torch.cat(anchors, 0)
+ pos_scores = cls_score[pos_inds]
+ pos_bbox_pred = bbox_pred[pos_inds]
+ pos_label = label[pos_inds]
+ pos_label_weight = label_weight[pos_inds]
+ pos_bbox_target = bbox_target[pos_inds]
+ pos_bbox_weight = bbox_weight[pos_inds]
+ pos_anchors = anchors_all_level[pos_inds]
+ pos_bbox_pred = self.bbox_coder.decode(pos_anchors, pos_bbox_pred)
+
+ # to keep loss dimension
+ loss_cls = self.loss_cls(
+ pos_scores,
+ pos_label,
+ pos_label_weight,
+ avg_factor=self.loss_cls.loss_weight,
+ reduction_override='none')
+
+ loss_bbox = self.loss_bbox(
+ pos_bbox_pred,
+ pos_bbox_target,
+ pos_bbox_weight,
+ avg_factor=self.loss_cls.loss_weight,
+ reduction_override='none')
+
+ loss_cls = loss_cls.sum(-1)
+ pos_loss = loss_bbox + loss_cls
+ return pos_loss,
+
+ def paa_reassign(self, pos_losses, label, label_weight, bbox_weight,
+ pos_inds, pos_gt_inds, anchors):
+ """Fit loss to GMM distribution and separate positive, ignore, negative
+ samples again with GMM model.
+
+ Args:
+ pos_losses (Tensor): Losses of all positive samples in
+ single image.
+ label (Tensor): classification target of each anchor with
+ shape (num_anchors,)
+ label_weight (Tensor): Classification loss weight of each
+ anchor with shape (num_anchors).
+ bbox_weight (Tensor): Bbox weight of each anchor with shape
+ (num_anchors, 4).
+ pos_inds (Tensor): Index of all positive samples got from
+ first assign process.
+ pos_gt_inds (Tensor): Gt_index of all positive samples got
+ from first assign process.
+ anchors (list[Tensor]): Anchors of each scale.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - label (Tensor): classification target of each anchor after
+ paa assign, with shape (num_anchors,)
+ - label_weight (Tensor): Classification loss weight of each
+ anchor after paa assign, with shape (num_anchors).
+ - bbox_weight (Tensor): Bbox weight of each anchor with shape
+ (num_anchors, 4).
+ - num_pos (int): The number of positive samples after paa
+ assign.
+ """
+ if not len(pos_inds):
+ return label, label_weight, bbox_weight, 0
+ label = label.clone()
+ label_weight = label_weight.clone()
+ bbox_weight = bbox_weight.clone()
+ num_gt = pos_gt_inds.max() + 1
+ num_level = len(anchors)
+ num_anchors_each_level = [item.size(0) for item in anchors]
+ num_anchors_each_level.insert(0, 0)
+ inds_level_interval = np.cumsum(num_anchors_each_level)
+ pos_level_mask = []
+ for i in range(num_level):
+ mask = (pos_inds >= inds_level_interval[i]) & (
+ pos_inds < inds_level_interval[i + 1])
+ pos_level_mask.append(mask)
+ pos_inds_after_paa = [label.new_tensor([])]
+ ignore_inds_after_paa = [label.new_tensor([])]
+ for gt_ind in range(num_gt):
+ pos_inds_gmm = []
+ pos_loss_gmm = []
+ gt_mask = pos_gt_inds == gt_ind
+ for level in range(num_level):
+ level_mask = pos_level_mask[level]
+ level_gt_mask = level_mask & gt_mask
+ value, topk_inds = pos_losses[level_gt_mask].topk(
+ min(level_gt_mask.sum(), self.topk), largest=False)
+ pos_inds_gmm.append(pos_inds[level_gt_mask][topk_inds])
+ pos_loss_gmm.append(value)
+ pos_inds_gmm = torch.cat(pos_inds_gmm)
+ pos_loss_gmm = torch.cat(pos_loss_gmm)
+ # fix gmm need at least two sample
+ if len(pos_inds_gmm) < 2:
+ continue
+ device = pos_inds_gmm.device
+ pos_loss_gmm, sort_inds = pos_loss_gmm.sort()
+ pos_inds_gmm = pos_inds_gmm[sort_inds]
+ pos_loss_gmm = pos_loss_gmm.view(-1, 1).cpu().numpy()
+ min_loss, max_loss = pos_loss_gmm.min(), pos_loss_gmm.max()
+ means_init = np.array([min_loss, max_loss]).reshape(2, 1)
+ weights_init = np.array([0.5, 0.5])
+ precisions_init = np.array([1.0, 1.0]).reshape(2, 1, 1) # full
+ if self.covariance_type == 'spherical':
+ precisions_init = precisions_init.reshape(2)
+ elif self.covariance_type == 'diag':
+ precisions_init = precisions_init.reshape(2, 1)
+ elif self.covariance_type == 'tied':
+ precisions_init = np.array([[1.0]])
+ if skm is None:
+                raise ImportError('Please run "pip install scikit-learn" '
+                                  'to install scikit-learn first.')
+ gmm = skm.GaussianMixture(
+ 2,
+ weights_init=weights_init,
+ means_init=means_init,
+ precisions_init=precisions_init,
+ covariance_type=self.covariance_type)
+ gmm.fit(pos_loss_gmm)
+ gmm_assignment = gmm.predict(pos_loss_gmm)
+ scores = gmm.score_samples(pos_loss_gmm)
+ gmm_assignment = torch.from_numpy(gmm_assignment).to(device)
+ scores = torch.from_numpy(scores).to(device)
+
+ pos_inds_temp, ignore_inds_temp = self.gmm_separation_scheme(
+ gmm_assignment, scores, pos_inds_gmm)
+ pos_inds_after_paa.append(pos_inds_temp)
+ ignore_inds_after_paa.append(ignore_inds_temp)
+
+ pos_inds_after_paa = torch.cat(pos_inds_after_paa)
+ ignore_inds_after_paa = torch.cat(ignore_inds_after_paa)
+ reassign_mask = (pos_inds.unsqueeze(1) != pos_inds_after_paa).all(1)
+ reassign_ids = pos_inds[reassign_mask]
+ label[reassign_ids] = self.num_classes
+ label_weight[ignore_inds_after_paa] = 0
+ bbox_weight[reassign_ids] = 0
+ num_pos = len(pos_inds_after_paa)
+ return label, label_weight, bbox_weight, num_pos
+
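+    # Informal summary of the reassignment above: for every gt, the topk
+    # lowest-loss candidates of each level are pooled, their sorted losses are
+    # fit with a 2-component 1-D GaussianMixture initialised at
+    # (min_loss, max_loss), and gmm_separation_scheme below treats the low-loss
+    # component as positives (scheme (c) of the paper); everything else is
+    # pushed back to background (label = num_classes) or ignored
+    # (label_weight = 0).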
+ def gmm_separation_scheme(self, gmm_assignment, scores, pos_inds_gmm):
+ """A general separation scheme for gmm model.
+
+ It separates a GMM distribution of candidate samples into three
+ parts, 0 1 and uncertain areas, and you can implement other
+ separation schemes by rewriting this function.
+
+ Args:
+ gmm_assignment (Tensor): The prediction of GMM which is of shape
+ (num_samples,). The 0/1 value indicates the distribution
+ that each sample comes from.
+ scores (Tensor): The probability of sample coming from the
+ fit GMM distribution. The tensor is of shape (num_samples,).
+ pos_inds_gmm (Tensor): All the indexes of samples which are used
+ to fit GMM model. The tensor is of shape (num_samples,)
+
+ Returns:
+ tuple[Tensor]: The indices of positive and ignored samples.
+
+ - pos_inds_temp (Tensor): Indices of positive samples.
+ - ignore_inds_temp (Tensor): Indices of ignore samples.
+ """
+ # The implementation is (c) in Fig.3 in origin paper instead of (b).
+ # You can refer to issues such as
+ # https://github.com/kkhoot/PAA/issues/8 and
+ # https://github.com/kkhoot/PAA/issues/9.
+ fgs = gmm_assignment == 0
+ pos_inds_temp = fgs.new_tensor([], dtype=torch.long)
+ ignore_inds_temp = fgs.new_tensor([], dtype=torch.long)
+ if fgs.nonzero().numel():
+ _, pos_thr_ind = scores[fgs].topk(1)
+ pos_inds_temp = pos_inds_gmm[fgs][:pos_thr_ind + 1]
+ ignore_inds_temp = pos_inds_gmm.new_tensor([])
+ return pos_inds_temp, ignore_inds_temp
+
+ def get_targets(
+ self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=1,
+ unmap_outputs=True,
+ ):
+ """Get targets for PAA head.
+
+        This method is almost the same as `AnchorHead.get_targets()`. We
+        directly return the results from `_get_targets_single` instead of
+        mapping them to levels with the `images_to_levels` function.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, 4).
+ valid_flag_list (list[list[Tensor]]): Multi level valid flags of
+ each image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+
+ - labels (list[Tensor]): Labels of all anchors, each with
+ shape (num_anchors,).
+ - label_weights (list[Tensor]): Label weights of all anchor.
+ each with shape (num_anchors,).
+ - bbox_targets (list[Tensor]): BBox targets of all anchors.
+ each with shape (num_anchors, 4).
+ - bbox_weights (list[Tensor]): BBox weights of all anchors.
+ each with shape (num_anchors, 4).
+ - pos_inds (list[Tensor]): Contains all index of positive
+ sample in all anchor.
+ - gt_inds (list[Tensor]): Contains all gt_index of positive
+ sample in all anchor.
+ """
+
+ num_imgs = len(img_metas)
+ assert len(anchor_list) == len(valid_flag_list) == num_imgs
+ concat_anchor_list = []
+ concat_valid_flag_list = []
+ for i in range(num_imgs):
+ assert len(anchor_list[i]) == len(valid_flag_list[i])
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ concat_valid_flag_list.append(torch.cat(valid_flag_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ results = multi_apply(
+ self._get_targets_single,
+ concat_anchor_list,
+ concat_valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+
+ (labels, label_weights, bbox_targets, bbox_weights, valid_pos_inds,
+ valid_neg_inds, sampling_result) = results
+
+ # Due to valid flag of anchors, we have to calculate the real pos_inds
+ # in origin anchor set.
+ pos_inds = []
+ for i, single_labels in enumerate(labels):
+ pos_mask = (0 <= single_labels) & (
+ single_labels < self.num_classes)
+ pos_inds.append(pos_mask.nonzero().view(-1))
+
+ gt_inds = [item.pos_assigned_gt_inds for item in sampling_result]
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ gt_inds)
+
+ def _get_targets_single(self,
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ This method is same as `AnchorHead._get_targets_single()`.
+ """
+        assert unmap_outputs, 'We must map outputs back to the original ' \
+            'set of anchors in PAAHead'
+ return super(ATSSHead, self)._get_targets_single(
+ flat_anchors,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=1,
+ unmap_outputs=True)
+
+ def _get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ iou_preds,
+ mlvl_anchors,
+ img_shapes,
+ scale_factors,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into labeled boxes.
+
+        This method is almost the same as `ATSSHead._get_bboxes()`.
+        We use sqrt(iou_preds * cls_scores) in the NMS process instead of just
+        cls_scores. Besides, score voting is used when ``score_voting``
+        is set to True.
+ """
+ assert with_nms, 'PAA only supports "with_nms=True" now'
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_anchors)
+ batch_size = cls_scores[0].shape[0]
+
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_iou_preds = []
+ for cls_score, bbox_pred, iou_preds, anchors in zip(
+ cls_scores, bbox_preds, iou_preds, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+
+ scores = cls_score.permute(0, 2, 3, 1).reshape(
+ batch_size, -1, self.cls_out_channels).sigmoid()
+ bbox_pred = bbox_pred.permute(0, 2, 3,
+ 1).reshape(batch_size, -1, 4)
+ iou_preds = iou_preds.permute(0, 2, 3, 1).reshape(batch_size,
+ -1).sigmoid()
+
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[1] > nms_pre:
+ max_scores, _ = (scores * iou_preds[..., None]).sqrt().max(-1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ scores = scores[batch_inds, topk_inds, :]
+ iou_preds = iou_preds[batch_inds, topk_inds]
+ else:
+ anchors = anchors.expand_as(bbox_pred)
+
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shapes)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_iou_preds.append(iou_preds)
+
+ batch_mlvl_bboxes = torch.cat(mlvl_bboxes, dim=1)
+ if rescale:
+ batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = batch_mlvl_scores.new_zeros(batch_size,
+ batch_mlvl_scores.shape[1], 1)
+ batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
+ batch_mlvl_iou_preds = torch.cat(mlvl_iou_preds, dim=1)
+ batch_mlvl_nms_scores = (batch_mlvl_scores *
+ batch_mlvl_iou_preds[..., None]).sqrt()
+
+ det_results = []
+ for (mlvl_bboxes, mlvl_scores) in zip(batch_mlvl_bboxes,
+ batch_mlvl_nms_scores):
+ det_bbox, det_label = multiclass_nms(
+ mlvl_bboxes,
+ mlvl_scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=None)
+ if self.with_score_voting and len(det_bbox) > 0:
+ det_bbox, det_label = self.score_voting(
+ det_bbox, det_label, mlvl_bboxes, mlvl_scores,
+ cfg.score_thr)
+ det_results.append(tuple([det_bbox, det_label]))
+
+ return det_results
+
+ def score_voting(self, det_bboxes, det_labels, mlvl_bboxes,
+ mlvl_nms_scores, score_thr):
+ """Implementation of score voting method works on each remaining boxes
+ after NMS procedure.
+
+ Args:
+ det_bboxes (Tensor): Remaining boxes after NMS procedure,
+ with shape (k, 5), each dimension means
+ (x1, y1, x2, y2, score).
+ det_labels (Tensor): The label of remaining boxes, with shape
+ (k, 1),Labels are 0-based.
+ mlvl_bboxes (Tensor): All boxes before the NMS procedure,
+ with shape (num_anchors,4).
+ mlvl_nms_scores (Tensor): The scores of all boxes which is used
+ in the NMS procedure, with shape (num_anchors, num_class)
+ mlvl_iou_preds (Tensor): The predictions of IOU of all boxes
+ before the NMS procedure, with shape (num_anchors, 1)
+ score_thr (float): The score threshold of bboxes.
+
+ Returns:
+ tuple: Usually returns a tuple containing voting results.
+
+                - det_bboxes_voted (Tensor): Remaining boxes after
+                    the score voting procedure, with shape (k, 5), each
+                    row being (x1, y1, x2, y2, score).
+                - det_labels_voted (Tensor): Labels of the remaining
+                    boxes after voting, with shape (k,).
+ """
+ candidate_mask = mlvl_nms_scores > score_thr
+ candidate_mask_nonzeros = candidate_mask.nonzero()
+ candidate_inds = candidate_mask_nonzeros[:, 0]
+ candidate_labels = candidate_mask_nonzeros[:, 1]
+ candidate_bboxes = mlvl_bboxes[candidate_inds]
+ candidate_scores = mlvl_nms_scores[candidate_mask]
+ det_bboxes_voted = []
+ det_labels_voted = []
+ for cls in range(self.cls_out_channels):
+ candidate_cls_mask = candidate_labels == cls
+ if not candidate_cls_mask.any():
+ continue
+ candidate_cls_scores = candidate_scores[candidate_cls_mask]
+ candidate_cls_bboxes = candidate_bboxes[candidate_cls_mask]
+ det_cls_mask = det_labels == cls
+ det_cls_bboxes = det_bboxes[det_cls_mask].view(
+ -1, det_bboxes.size(-1))
+ det_candidate_ious = bbox_overlaps(det_cls_bboxes[:, :4],
+ candidate_cls_bboxes)
+ for det_ind in range(len(det_cls_bboxes)):
+ single_det_ious = det_candidate_ious[det_ind]
+ pos_ious_mask = single_det_ious > 0.01
+ pos_ious = single_det_ious[pos_ious_mask]
+ pos_bboxes = candidate_cls_bboxes[pos_ious_mask]
+ pos_scores = candidate_cls_scores[pos_ious_mask]
+ pis = (torch.exp(-(1 - pos_ious)**2 / 0.025) *
+ pos_scores)[:, None]
+ voted_box = torch.sum(
+ pis * pos_bboxes, dim=0) / torch.sum(
+ pis, dim=0)
+ voted_score = det_cls_bboxes[det_ind][-1:][None, :]
+ det_bboxes_voted.append(
+ torch.cat((voted_box[None, :], voted_score), dim=1))
+ det_labels_voted.append(cls)
+
+ det_bboxes_voted = torch.cat(det_bboxes_voted, dim=0)
+ det_labels_voted = det_labels.new_tensor(det_labels_voted)
+ return det_bboxes_voted, det_labels_voted
diff --git a/mmdet/models/dense_heads/pisa_retinanet_head.py b/mmdet/models/dense_heads/pisa_retinanet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd87b9aeb07e05ff94b444ac8999eca3f616711a
--- /dev/null
+++ b/mmdet/models/dense_heads/pisa_retinanet_head.py
@@ -0,0 +1,154 @@
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.core import images_to_levels
+from ..builder import HEADS
+from ..losses import carl_loss, isr_p
+from .retina_head import RetinaHead
+
+
+@HEADS.register_module()
+class PISARetinaHead(RetinaHead):
+ """PISA Retinanet Head.
+
+ The head owns the same structure with Retinanet Head, but differs in two
+ aspects:
+ 1. Importance-based Sample Reweighting Positive (ISR-P) is applied to
+ change the positive loss weights.
+ 2. Classification-aware regression loss is adopted as a third loss.
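+
+    Example:
+        >>> # Illustrative `train_cfg` fragment enabling both components;
+        >>> # the k/bias values below follow common PISA configs and are
+        >>> # assumptions, not requirements:
+        >>> train_cfg = dict(
+        ...     isr=dict(k=2., bias=0.),
+        ...     carl=dict(k=1., bias=0.2))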
+ """
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            gt_bboxes (list[Tensor]): Ground truth bboxes of each image
+                with shape (num_obj, 4).
+            gt_labels (list[Tensor]): Ground truth labels of each image
+                with shape (num_obj,).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
+ Default: None.
+
+ Returns:
+            dict: Loss dict comprising the classification loss, the
+                regression loss and the CARL loss.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ return_sampling_results=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+
+ num_imgs = len(img_metas)
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3, 1).reshape(num_imgs, -1, label_channels)
+ for cls_score in cls_scores
+ ]
+ flatten_cls_scores = torch.cat(
+ flatten_cls_scores, dim=1).reshape(-1,
+ flatten_cls_scores[0].size(-1))
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(num_imgs, -1, 4)
+ for bbox_pred in bbox_preds
+ ]
+ flatten_bbox_preds = torch.cat(
+ flatten_bbox_preds, dim=1).view(-1, flatten_bbox_preds[0].size(-1))
+ flatten_labels = torch.cat(labels_list, dim=1).reshape(-1)
+ flatten_label_weights = torch.cat(
+ label_weights_list, dim=1).reshape(-1)
+ flatten_anchors = torch.cat(all_anchor_list, dim=1).reshape(-1, 4)
+ flatten_bbox_targets = torch.cat(
+ bbox_targets_list, dim=1).reshape(-1, 4)
+ flatten_bbox_weights = torch.cat(
+ bbox_weights_list, dim=1).reshape(-1, 4)
+
+ # Apply ISR-P
+ isr_cfg = self.train_cfg.get('isr', None)
+ if isr_cfg is not None:
+ all_targets = (flatten_labels, flatten_label_weights,
+ flatten_bbox_targets, flatten_bbox_weights)
+ with torch.no_grad():
+ all_targets = isr_p(
+ flatten_cls_scores,
+ flatten_bbox_preds,
+ all_targets,
+ flatten_anchors,
+ sampling_results_list,
+ bbox_coder=self.bbox_coder,
+ loss_cls=self.loss_cls,
+ num_class=self.num_classes,
+ **self.train_cfg.isr)
+ (flatten_labels, flatten_label_weights, flatten_bbox_targets,
+ flatten_bbox_weights) = all_targets
+
+        # For convenience we compute the loss over all FPN levels at once,
+        # so the (re)weighted targets do not need to be split back into
+        # per-level tensors. The result is the same either way.
+ losses_cls = self.loss_cls(
+ flatten_cls_scores,
+ flatten_labels,
+ flatten_label_weights,
+ avg_factor=num_total_samples)
+ losses_bbox = self.loss_bbox(
+ flatten_bbox_preds,
+ flatten_bbox_targets,
+ flatten_bbox_weights,
+ avg_factor=num_total_samples)
+ loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+
+ # CARL Loss
+ carl_cfg = self.train_cfg.get('carl', None)
+ if carl_cfg is not None:
+ loss_carl = carl_loss(
+ flatten_cls_scores,
+ flatten_labels,
+ flatten_bbox_preds,
+ flatten_bbox_targets,
+ self.loss_bbox,
+ **self.train_cfg.carl,
+ avg_factor=num_total_pos,
+ sigmoid=True,
+ num_class=self.num_classes)
+ loss_dict.update(loss_carl)
+
+ return loss_dict
diff --git a/mmdet/models/dense_heads/pisa_ssd_head.py b/mmdet/models/dense_heads/pisa_ssd_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ef3c83ed62d8346c8daef01f18ad7bd236623c
--- /dev/null
+++ b/mmdet/models/dense_heads/pisa_ssd_head.py
@@ -0,0 +1,139 @@
+import torch
+
+from mmdet.core import multi_apply
+from ..builder import HEADS
+from ..losses import CrossEntropyLoss, SmoothL1Loss, carl_loss, isr_p
+from .ssd_head import SSDHead
+
+
+# TODO: add loss evaluator for SSD
+@HEADS.register_module()
+class PISASSDHead(SSDHead):
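+    """SSD head with PISA modules.
+
+    Like `PISARetinaHead`, it applies importance-based sample reweighting
+    (ISR-P) to the positive samples and adds a classification-aware
+    regression loss (CARL) as an extra loss term.
+    """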
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+            gt_bboxes (list[Tensor]): Ground truth bboxes of each image
+                with shape (num_obj, 4).
+            gt_labels (list[Tensor]): Ground truth labels of each image
+                with shape (num_obj,).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (list[Tensor]): Ignored gt bboxes of each image.
+ Default: None.
+
+ Returns:
+            dict: Loss dict comprising the classification loss, the
+                regression loss and the CARL loss.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=1,
+ unmap_outputs=False,
+ return_sampling_results=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, sampling_results_list) = cls_reg_targets
+
+ num_images = len(img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ isr_cfg = self.train_cfg.get('isr', None)
+        all_targets = (all_labels.view(-1), all_label_weights.view(-1),
+                       all_bbox_targets.view(-1, 4),
+                       all_bbox_weights.view(-1, 4))
+ # apply ISR-P
+ if isr_cfg is not None:
+ all_targets = isr_p(
+ all_cls_scores.view(-1, all_cls_scores.size(-1)),
+ all_bbox_preds.view(-1, 4),
+ all_targets,
+ torch.cat(all_anchors),
+ sampling_results_list,
+ loss_cls=CrossEntropyLoss(),
+ bbox_coder=self.bbox_coder,
+ **self.train_cfg.isr,
+ num_class=self.num_classes)
+ (new_labels, new_label_weights, new_bbox_targets,
+ new_bbox_weights) = all_targets
+ all_labels = new_labels.view(all_labels.shape)
+ all_label_weights = new_label_weights.view(all_label_weights.shape)
+ all_bbox_targets = new_bbox_targets.view(all_bbox_targets.shape)
+ all_bbox_weights = new_bbox_weights.view(all_bbox_weights.shape)
+
+ # add CARL loss
+ carl_loss_cfg = self.train_cfg.get('carl', None)
+ if carl_loss_cfg is not None:
+ loss_carl = carl_loss(
+ all_cls_scores.view(-1, all_cls_scores.size(-1)),
+ all_targets[0],
+ all_bbox_preds.view(-1, 4),
+ all_targets[2],
+ SmoothL1Loss(beta=1.),
+ **self.train_cfg.carl,
+ avg_factor=num_total_pos,
+ num_class=self.num_classes)
+
+ # check NaN and Inf
+ assert torch.isfinite(all_cls_scores).all().item(), \
+ 'classification scores become infinite or NaN!'
+ assert torch.isfinite(all_bbox_preds).all().item(), \
+            'bbox predictions become infinite or NaN!'
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ num_total_samples=num_total_pos)
+ loss_dict = dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
+ if carl_loss_cfg is not None:
+ loss_dict.update(loss_carl)
+ return loss_dict
diff --git a/mmdet/models/dense_heads/reppoints_head.py b/mmdet/models/dense_heads/reppoints_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..499cc4f71c968704a40ab2bb7a6b22dd079d82de
--- /dev/null
+++ b/mmdet/models/dense_heads/reppoints_head.py
@@ -0,0 +1,763 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.ops import DeformConv2d
+
+from mmdet.core import (PointGenerator, build_assigner, build_sampler,
+ images_to_levels, multi_apply, multiclass_nms, unmap)
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+
+@HEADS.register_module()
+class RepPointsHead(AnchorFreeHead):
+ """RepPoint head.
+
+ Args:
+ point_feat_channels (int): Number of channels of points features.
+ gradient_mul (float): The multiplier to gradients from
+ points refinement and recognition.
+ point_strides (Iterable): points strides.
+ point_base_scale (int): bbox scale for assigning labels.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox_init (dict): Config of initial points loss.
+ loss_bbox_refine (dict): Config of points loss in refinement.
+        use_grid_points (bool): If True, a bounding box representation is
+            used and the RepPoints are represented as grid points on the
+            bounding box.
+ center_init (bool): Whether to use center point assignment.
+ transform_method (str): The methods to transform RepPoints to bbox.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ point_feat_channels=256,
+ num_points=9,
+ gradient_mul=0.1,
+ point_strides=[8, 16, 32, 64, 128],
+ point_base_scale=4,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_init=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=0.5),
+ loss_bbox_refine=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.0),
+ use_grid_points=False,
+ center_init=True,
+ transform_method='moment',
+ moment_mul=0.01,
+ **kwargs):
+ self.num_points = num_points
+ self.point_feat_channels = point_feat_channels
+ self.use_grid_points = use_grid_points
+ self.center_init = center_init
+
+ # we use deform conv to extract points features
+ self.dcn_kernel = int(np.sqrt(num_points))
+ self.dcn_pad = int((self.dcn_kernel - 1) / 2)
+        assert self.dcn_kernel * self.dcn_kernel == num_points, \
+            'The number of points should be a square number.'
+        assert self.dcn_kernel % 2 == 1, \
+            'The number of points should be an odd square number.'
+ dcn_base = np.arange(-self.dcn_pad,
+ self.dcn_pad + 1).astype(np.float64)
+ dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
+ dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
+ dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
+ (-1))
+ self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
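+        # e.g. for the default num_points=9 this is the 3x3 grid
+        # (-1,-1), (-1,0), ..., (1,1) in (y, x) order, i.e. the sampling
+        # locations of a standard 3x3 convolution.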
+
+ super().__init__(num_classes, in_channels, loss_cls=loss_cls, **kwargs)
+
+ self.gradient_mul = gradient_mul
+ self.point_base_scale = point_base_scale
+ self.point_strides = point_strides
+ self.point_generators = [PointGenerator() for _ in self.point_strides]
+
+ self.sampling = loss_cls['type'] not in ['FocalLoss']
+ if self.train_cfg:
+ self.init_assigner = build_assigner(self.train_cfg.init.assigner)
+ self.refine_assigner = build_assigner(
+ self.train_cfg.refine.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.transform_method = transform_method
+ if self.transform_method == 'moment':
+ self.moment_transfer = nn.Parameter(
+ data=torch.zeros(2), requires_grad=True)
+ self.moment_mul = moment_mul
+
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = self.num_classes
+ else:
+ self.cls_out_channels = self.num_classes + 1
+ self.loss_bbox_init = build_loss(loss_bbox_init)
+ self.loss_bbox_refine = build_loss(loss_bbox_refine)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ pts_out_dim = 4 if self.use_grid_points else 2 * self.num_points
+ self.reppoints_cls_conv = DeformConv2d(self.feat_channels,
+ self.point_feat_channels,
+ self.dcn_kernel, 1,
+ self.dcn_pad)
+ self.reppoints_cls_out = nn.Conv2d(self.point_feat_channels,
+ self.cls_out_channels, 1, 1, 0)
+ self.reppoints_pts_init_conv = nn.Conv2d(self.feat_channels,
+ self.point_feat_channels, 3,
+ 1, 1)
+ self.reppoints_pts_init_out = nn.Conv2d(self.point_feat_channels,
+ pts_out_dim, 1, 1, 0)
+ self.reppoints_pts_refine_conv = DeformConv2d(self.feat_channels,
+ self.point_feat_channels,
+ self.dcn_kernel, 1,
+ self.dcn_pad)
+ self.reppoints_pts_refine_out = nn.Conv2d(self.point_feat_channels,
+ pts_out_dim, 1, 1, 0)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.reppoints_cls_conv, std=0.01)
+ normal_init(self.reppoints_cls_out, std=0.01, bias=bias_cls)
+ normal_init(self.reppoints_pts_init_conv, std=0.01)
+ normal_init(self.reppoints_pts_init_out, std=0.01)
+ normal_init(self.reppoints_pts_refine_conv, std=0.01)
+ normal_init(self.reppoints_pts_refine_out, std=0.01)
+
+ def points2bbox(self, pts, y_first=True):
+ """Converting the points set into bounding box.
+
+ :param pts: the input points sets (fields), each points
+ set (fields) is represented as 2n scalar.
+ :param y_first: if y_first=True, the point set is represented as
+ [y1, x1, y2, x2 ... yn, xn], otherwise the point set is
+ represented as [x1, y1, x2, y2 ... xn, yn].
+ :return: each points set is converting to a bbox [x1, y1, x2, y2].
+ """
+ pts_reshape = pts.view(pts.shape[0], -1, 2, *pts.shape[2:])
+ pts_y = pts_reshape[:, :, 0, ...] if y_first else pts_reshape[:, :, 1,
+ ...]
+ pts_x = pts_reshape[:, :, 1, ...] if y_first else pts_reshape[:, :, 0,
+ ...]
+ if self.transform_method == 'minmax':
+ bbox_left = pts_x.min(dim=1, keepdim=True)[0]
+ bbox_right = pts_x.max(dim=1, keepdim=True)[0]
+ bbox_up = pts_y.min(dim=1, keepdim=True)[0]
+ bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
+ bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
+ dim=1)
+ elif self.transform_method == 'partial_minmax':
+ pts_y = pts_y[:, :4, ...]
+ pts_x = pts_x[:, :4, ...]
+ bbox_left = pts_x.min(dim=1, keepdim=True)[0]
+ bbox_right = pts_x.max(dim=1, keepdim=True)[0]
+ bbox_up = pts_y.min(dim=1, keepdim=True)[0]
+ bbox_bottom = pts_y.max(dim=1, keepdim=True)[0]
+ bbox = torch.cat([bbox_left, bbox_up, bbox_right, bbox_bottom],
+ dim=1)
+ elif self.transform_method == 'moment':
+ pts_y_mean = pts_y.mean(dim=1, keepdim=True)
+ pts_x_mean = pts_x.mean(dim=1, keepdim=True)
+ pts_y_std = torch.std(pts_y - pts_y_mean, dim=1, keepdim=True)
+ pts_x_std = torch.std(pts_x - pts_x_mean, dim=1, keepdim=True)
+ moment_transfer = (self.moment_transfer * self.moment_mul) + (
+ self.moment_transfer.detach() * (1 - self.moment_mul))
+ moment_width_transfer = moment_transfer[0]
+ moment_height_transfer = moment_transfer[1]
+ half_width = pts_x_std * torch.exp(moment_width_transfer)
+ half_height = pts_y_std * torch.exp(moment_height_transfer)
+ bbox = torch.cat([
+ pts_x_mean - half_width, pts_y_mean - half_height,
+ pts_x_mean + half_width, pts_y_mean + half_height
+ ],
+ dim=1)
+ else:
+ raise NotImplementedError
+ return bbox
+
+ def gen_grid_from_reg(self, reg, previous_boxes):
+ """Base on the previous bboxes and regression values, we compute the
+ regressed bboxes and generate the grids on the bboxes.
+
+ :param reg: the regression value to previous bboxes.
+ :param previous_boxes: previous bboxes.
+ :return: generate grids on the regressed bboxes.
+ """
+ b, _, h, w = reg.shape
+ bxy = (previous_boxes[:, :2, ...] + previous_boxes[:, 2:, ...]) / 2.
+ bwh = (previous_boxes[:, 2:, ...] -
+ previous_boxes[:, :2, ...]).clamp(min=1e-6)
+ grid_topleft = bxy + bwh * reg[:, :2, ...] - 0.5 * bwh * torch.exp(
+ reg[:, 2:, ...])
+ grid_wh = bwh * torch.exp(reg[:, 2:, ...])
+ grid_left = grid_topleft[:, [0], ...]
+ grid_top = grid_topleft[:, [1], ...]
+ grid_width = grid_wh[:, [0], ...]
+ grid_height = grid_wh[:, [1], ...]
+        interval = torch.linspace(0., 1., self.dcn_kernel).view(
+            1, self.dcn_kernel, 1, 1).type_as(reg)
+        grid_x = grid_left + grid_width * interval
+        grid_x = grid_x.unsqueeze(1).repeat(1, self.dcn_kernel, 1, 1, 1)
+        grid_x = grid_x.view(b, -1, h, w)
+        grid_y = grid_top + grid_height * interval
+ grid_y = grid_y.unsqueeze(2).repeat(1, 1, self.dcn_kernel, 1, 1)
+ grid_y = grid_y.view(b, -1, h, w)
+ grid_yx = torch.stack([grid_y, grid_x], dim=2)
+ grid_yx = grid_yx.view(b, -1, h, w)
+ regressed_bbox = torch.cat([
+ grid_left, grid_top, grid_left + grid_width, grid_top + grid_height
+ ], 1)
+ return grid_yx, regressed_bbox
+
+ def forward(self, feats):
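+        """Forward features from the upstream network.
+
+        Applies `forward_single` to the feature map of each FPN level.
+        """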
+ return multi_apply(self.forward_single, feats)
+
+ def forward_single(self, x):
+ """Forward feature map of a single FPN level."""
+ dcn_base_offset = self.dcn_base_offset.type_as(x)
+        # If center_init is used, the initial RepPoints come from the center
+        # points. If the bounding box representation is used, the initial
+        # RepPoints come from a regular grid placed on a pre-defined bbox.
+ if self.use_grid_points or not self.center_init:
+ scale = self.point_base_scale / 2
+ points_init = dcn_base_offset / dcn_base_offset.max() * scale
+ bbox_init = x.new_tensor([-scale, -scale, scale,
+ scale]).view(1, 4, 1, 1)
+ else:
+ points_init = 0
+ cls_feat = x
+ pts_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ pts_feat = reg_conv(pts_feat)
+ # initialize reppoints
+ pts_out_init = self.reppoints_pts_init_out(
+ self.relu(self.reppoints_pts_init_conv(pts_feat)))
+ if self.use_grid_points:
+ pts_out_init, bbox_out_init = self.gen_grid_from_reg(
+ pts_out_init, bbox_init.detach())
+ else:
+ pts_out_init = pts_out_init + points_init
+ # refine and classify reppoints
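+        # The init points feed the DCN offsets below; gradient_mul leaves
+        # the forward value unchanged but lets only a fraction gradient_mul
+        # of the gradient flow back into the init-stage branch.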
+ pts_out_init_grad_mul = (1 - self.gradient_mul) * pts_out_init.detach(
+ ) + self.gradient_mul * pts_out_init
+ dcn_offset = pts_out_init_grad_mul - dcn_base_offset
+ cls_out = self.reppoints_cls_out(
+ self.relu(self.reppoints_cls_conv(cls_feat, dcn_offset)))
+ pts_out_refine = self.reppoints_pts_refine_out(
+ self.relu(self.reppoints_pts_refine_conv(pts_feat, dcn_offset)))
+ if self.use_grid_points:
+ pts_out_refine, bbox_out_refine = self.gen_grid_from_reg(
+ pts_out_refine, bbox_out_init.detach())
+ else:
+ pts_out_refine = pts_out_refine + pts_out_init.detach()
+ return cls_out, pts_out_init, pts_out_refine
+
+ def get_points(self, featmap_sizes, img_metas, device):
+ """Get points according to feature map sizes.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+
+ Returns:
+ tuple: points of each image, valid flags of each image
+ """
+ num_imgs = len(img_metas)
+ num_levels = len(featmap_sizes)
+
+        # since the feature map sizes of all images are the same, we only
+        # compute the point centers once
+ multi_level_points = []
+ for i in range(num_levels):
+ points = self.point_generators[i].grid_points(
+ featmap_sizes[i], self.point_strides[i], device)
+ multi_level_points.append(points)
+ points_list = [[point.clone() for point in multi_level_points]
+ for _ in range(num_imgs)]
+
+ # for each image, we compute valid flags of multi level grids
+ valid_flag_list = []
+ for img_id, img_meta in enumerate(img_metas):
+ multi_level_flags = []
+ for i in range(num_levels):
+ point_stride = self.point_strides[i]
+ feat_h, feat_w = featmap_sizes[i]
+ h, w = img_meta['pad_shape'][:2]
+ valid_feat_h = min(int(np.ceil(h / point_stride)), feat_h)
+ valid_feat_w = min(int(np.ceil(w / point_stride)), feat_w)
+ flags = self.point_generators[i].valid_flags(
+ (feat_h, feat_w), (valid_feat_h, valid_feat_w), device)
+ multi_level_flags.append(flags)
+ valid_flag_list.append(multi_level_flags)
+
+ return points_list, valid_flag_list
+
+ def centers_to_bboxes(self, point_list):
+ """Get bboxes according to center points.
+
+ Only used in :class:`MaxIoUAssigner`.
+ """
+ bbox_list = []
+ for i_img, point in enumerate(point_list):
+ bbox = []
+ for i_lvl in range(len(self.point_strides)):
+ scale = self.point_base_scale * self.point_strides[i_lvl] * 0.5
+ bbox_shift = torch.Tensor([-scale, -scale, scale,
+ scale]).view(1, 4).type_as(point[0])
+ bbox_center = torch.cat(
+ [point[i_lvl][:, :2], point[i_lvl][:, :2]], dim=1)
+ bbox.append(bbox_center + bbox_shift)
+ bbox_list.append(bbox)
+ return bbox_list
+
+ def offset_to_pts(self, center_list, pred_list):
+ """Change from point offset to point coordinate."""
+ pts_list = []
+ for i_lvl in range(len(self.point_strides)):
+ pts_lvl = []
+ for i_img in range(len(center_list)):
+ pts_center = center_list[i_img][i_lvl][:, :2].repeat(
+ 1, self.num_points)
+ pts_shift = pred_list[i_lvl][i_img]
+ yx_pts_shift = pts_shift.permute(1, 2, 0).view(
+ -1, 2 * self.num_points)
+ y_pts_shift = yx_pts_shift[..., 0::2]
+ x_pts_shift = yx_pts_shift[..., 1::2]
+ xy_pts_shift = torch.stack([x_pts_shift, y_pts_shift], -1)
+ xy_pts_shift = xy_pts_shift.view(*yx_pts_shift.shape[:-1], -1)
+ pts = xy_pts_shift * self.point_strides[i_lvl] + pts_center
+ pts_lvl.append(pts)
+ pts_lvl = torch.stack(pts_lvl, 0)
+ pts_list.append(pts_lvl)
+ return pts_list
+
+ def _point_target_single(self,
+ flat_proposals,
+ valid_flags,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ label_channels=1,
+ stage='init',
+ unmap_outputs=True):
+ inside_flags = valid_flags
+ if not inside_flags.any():
+ return (None, ) * 7
+ # assign gt and sample proposals
+ proposals = flat_proposals[inside_flags, :]
+
+ if stage == 'init':
+ assigner = self.init_assigner
+ pos_weight = self.train_cfg.init.pos_weight
+ else:
+ assigner = self.refine_assigner
+ pos_weight = self.train_cfg.refine.pos_weight
+ assign_result = assigner.assign(proposals, gt_bboxes, gt_bboxes_ignore,
+ None if self.sampling else gt_labels)
+ sampling_result = self.sampler.sample(assign_result, proposals,
+ gt_bboxes)
+
+ num_valid_proposals = proposals.shape[0]
+ bbox_gt = proposals.new_zeros([num_valid_proposals, 4])
+ pos_proposals = torch.zeros_like(proposals)
+ proposals_weights = proposals.new_zeros([num_valid_proposals, 4])
+ labels = proposals.new_full((num_valid_proposals, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = proposals.new_zeros(
+ num_valid_proposals, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ pos_gt_bboxes = sampling_result.pos_gt_bboxes
+ bbox_gt[pos_inds, :] = pos_gt_bboxes
+ pos_proposals[pos_inds, :] = proposals[pos_inds, :]
+ proposals_weights[pos_inds, :] = 1.0
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of proposals
+ if unmap_outputs:
+ num_total_proposals = flat_proposals.size(0)
+ labels = unmap(labels, num_total_proposals, inside_flags)
+ label_weights = unmap(label_weights, num_total_proposals,
+ inside_flags)
+ bbox_gt = unmap(bbox_gt, num_total_proposals, inside_flags)
+ pos_proposals = unmap(pos_proposals, num_total_proposals,
+ inside_flags)
+ proposals_weights = unmap(proposals_weights, num_total_proposals,
+ inside_flags)
+
+ return (labels, label_weights, bbox_gt, pos_proposals,
+ proposals_weights, pos_inds, neg_inds)
+
+ def get_targets(self,
+ proposals_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ stage='init',
+ label_channels=1,
+ unmap_outputs=True):
+ """Compute corresponding GT box and classification targets for
+ proposals.
+
+ Args:
+ proposals_list (list[list]): Multi level points/bboxes of each
+ image.
+ valid_flag_list (list[list]): Multi level valid flags of each
+ image.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): Ground truth bboxes to be
+ ignored.
+            gt_labels_list (list[Tensor]): Ground truth labels of each box.
+            stage (str): `init` or `refine`. Generate targets for the init
+                stage or the refine stage.
+ label_channels (int): Channel of label.
+ unmap_outputs (bool): Whether to map outputs back to the original
+ set of anchors.
+
+ Returns:
+ tuple:
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each level. # noqa: E501
+ - bbox_gt_list (list[Tensor]): Ground truth bbox of each level.
+ - proposal_list (list[Tensor]): Proposals(points/bboxes) of each level. # noqa: E501
+ - proposal_weights_list (list[Tensor]): Proposal weights of each level. # noqa: E501
+ - num_total_pos (int): Number of positive samples in all images. # noqa: E501
+ - num_total_neg (int): Number of negative samples in all images. # noqa: E501
+ """
+ assert stage in ['init', 'refine']
+ num_imgs = len(img_metas)
+ assert len(proposals_list) == len(valid_flag_list) == num_imgs
+
+ # points number of multi levels
+ num_level_proposals = [points.size(0) for points in proposals_list[0]]
+
+ # concat all level points and flags to a single tensor
+ for i in range(num_imgs):
+ assert len(proposals_list[i]) == len(valid_flag_list[i])
+ proposals_list[i] = torch.cat(proposals_list[i])
+ valid_flag_list[i] = torch.cat(valid_flag_list[i])
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_labels, all_label_weights, all_bbox_gt, all_proposals,
+ all_proposal_weights, pos_inds_list, neg_inds_list) = multi_apply(
+ self._point_target_single,
+ proposals_list,
+ valid_flag_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ stage=stage,
+ label_channels=label_channels,
+ unmap_outputs=unmap_outputs)
+ # no valid points
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled points of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ labels_list = images_to_levels(all_labels, num_level_proposals)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_proposals)
+ bbox_gt_list = images_to_levels(all_bbox_gt, num_level_proposals)
+ proposals_list = images_to_levels(all_proposals, num_level_proposals)
+ proposal_weights_list = images_to_levels(all_proposal_weights,
+ num_level_proposals)
+ return (labels_list, label_weights_list, bbox_gt_list, proposals_list,
+ proposal_weights_list, num_total_pos, num_total_neg)
+
+ def loss_single(self, cls_score, pts_pred_init, pts_pred_refine, labels,
+ label_weights, bbox_gt_init, bbox_weights_init,
+ bbox_gt_refine, bbox_weights_refine, stride,
+ num_total_samples_init, num_total_samples_refine):
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ cls_score = cls_score.contiguous()
+ loss_cls = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=num_total_samples_refine)
+
+ # points loss
+ bbox_gt_init = bbox_gt_init.reshape(-1, 4)
+ bbox_weights_init = bbox_weights_init.reshape(-1, 4)
+ bbox_pred_init = self.points2bbox(
+ pts_pred_init.reshape(-1, 2 * self.num_points), y_first=False)
+ bbox_gt_refine = bbox_gt_refine.reshape(-1, 4)
+ bbox_weights_refine = bbox_weights_refine.reshape(-1, 4)
+ bbox_pred_refine = self.points2bbox(
+ pts_pred_refine.reshape(-1, 2 * self.num_points), y_first=False)
+ normalize_term = self.point_base_scale * stride
+ loss_pts_init = self.loss_bbox_init(
+ bbox_pred_init / normalize_term,
+ bbox_gt_init / normalize_term,
+ bbox_weights_init,
+ avg_factor=num_total_samples_init)
+ loss_pts_refine = self.loss_bbox_refine(
+ bbox_pred_refine / normalize_term,
+ bbox_gt_refine / normalize_term,
+ bbox_weights_refine,
+ avg_factor=num_total_samples_refine)
+ return loss_cls, loss_pts_init, loss_pts_refine
+
+ def loss(self,
+ cls_scores,
+ pts_preds_init,
+ pts_preds_refine,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == len(self.point_generators)
+ device = cls_scores[0].device
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ # target for initial stage
+ center_list, valid_flag_list = self.get_points(featmap_sizes,
+ img_metas, device)
+ pts_coordinate_preds_init = self.offset_to_pts(center_list,
+ pts_preds_init)
+ if self.train_cfg.init.assigner['type'] == 'PointAssigner':
+ # Assign target for center list
+ candidate_list = center_list
+ else:
+ # transform center list to bbox list and
+ # assign target for bbox list
+ bbox_list = self.centers_to_bboxes(center_list)
+ candidate_list = bbox_list
+ cls_reg_targets_init = self.get_targets(
+ candidate_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ stage='init',
+ label_channels=label_channels)
+ (*_, bbox_gt_list_init, candidate_list_init, bbox_weights_list_init,
+ num_total_pos_init, num_total_neg_init) = cls_reg_targets_init
+ num_total_samples_init = (
+ num_total_pos_init +
+ num_total_neg_init if self.sampling else num_total_pos_init)
+
+ # target for refinement stage
+ center_list, valid_flag_list = self.get_points(featmap_sizes,
+ img_metas, device)
+ pts_coordinate_preds_refine = self.offset_to_pts(
+ center_list, pts_preds_refine)
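+        # Boxes decoded from the detached init-stage point predictions act
+        # as proposals when assigning refine-stage targets.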
+ bbox_list = []
+ for i_img, center in enumerate(center_list):
+ bbox = []
+ for i_lvl in range(len(pts_preds_refine)):
+ bbox_preds_init = self.points2bbox(
+ pts_preds_init[i_lvl].detach())
+ bbox_shift = bbox_preds_init * self.point_strides[i_lvl]
+ bbox_center = torch.cat(
+ [center[i_lvl][:, :2], center[i_lvl][:, :2]], dim=1)
+ bbox.append(bbox_center +
+ bbox_shift[i_img].permute(1, 2, 0).reshape(-1, 4))
+ bbox_list.append(bbox)
+ cls_reg_targets_refine = self.get_targets(
+ bbox_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ stage='refine',
+ label_channels=label_channels)
+ (labels_list, label_weights_list, bbox_gt_list_refine,
+ candidate_list_refine, bbox_weights_list_refine, num_total_pos_refine,
+ num_total_neg_refine) = cls_reg_targets_refine
+ num_total_samples_refine = (
+ num_total_pos_refine +
+ num_total_neg_refine if self.sampling else num_total_pos_refine)
+
+ # compute loss
+ losses_cls, losses_pts_init, losses_pts_refine = multi_apply(
+ self.loss_single,
+ cls_scores,
+ pts_coordinate_preds_init,
+ pts_coordinate_preds_refine,
+ labels_list,
+ label_weights_list,
+ bbox_gt_list_init,
+ bbox_weights_list_init,
+ bbox_gt_list_refine,
+ bbox_weights_list_refine,
+ self.point_strides,
+ num_total_samples_init=num_total_samples_init,
+ num_total_samples_refine=num_total_samples_refine)
+ loss_dict_all = {
+ 'loss_cls': losses_cls,
+ 'loss_pts_init': losses_pts_init,
+ 'loss_pts_refine': losses_pts_refine
+ }
+ return loss_dict_all
+
+ def get_bboxes(self,
+ cls_scores,
+ pts_preds_init,
+ pts_preds_refine,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ assert len(cls_scores) == len(pts_preds_refine)
+ device = cls_scores[0].device
+ bbox_preds_refine = [
+ self.points2bbox(pts_pred_refine)
+ for pts_pred_refine in pts_preds_refine
+ ]
+ num_levels = len(cls_scores)
+ mlvl_points = [
+ self.point_generators[i].grid_points(cls_scores[i].size()[-2:],
+ self.point_strides[i], device)
+ for i in range(num_levels)
+ ]
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds_refine[i][img_id].detach()
+ for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ mlvl_points, img_shape,
+ scale_factor, cfg, rescale,
+ with_nms)
+ result_list.append(proposals)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_points,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for i_lvl, (cls_score, bbox_pred, points) in enumerate(
+ zip(cls_scores, bbox_preds, mlvl_points)):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[0] > nms_pre:
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+                    # FG labels are set to [0, num_class-1] since mmdet
+                    # v2.0 and the BG cat_id is num_class, so the last
+                    # (background) column is excluded from the max.
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ points = points[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ bbox_pos_center = torch.cat([points[:, :2], points[:, :2]], dim=1)
+ bboxes = bbox_pred * self.point_strides[i_lvl] + bbox_pos_center
+ x1 = bboxes[:, 0].clamp(min=0, max=img_shape[1])
+ y1 = bboxes[:, 1].clamp(min=0, max=img_shape[0])
+ x2 = bboxes[:, 2].clamp(min=0, max=img_shape[1])
+ y2 = bboxes[:, 3].clamp(min=0, max=img_shape[0])
+ bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ if self.use_sigmoid_cls:
+            # Add a dummy background class to the back of the scores when
+            # using sigmoid. FG labels are set to [0, num_class-1] since
+            # mmdet v2.0, and the BG cat_id is num_class.
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ if with_nms:
+ det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ return det_bboxes, det_labels
+ else:
+ return mlvl_bboxes, mlvl_scores
diff --git a/mmdet/models/dense_heads/retina_head.py b/mmdet/models/dense_heads/retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..b12416fa8332f02b9a04bbfc7926f6d13875e61b
--- /dev/null
+++ b/mmdet/models/dense_heads/retina_head.py
@@ -0,0 +1,114 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class RetinaHead(AnchorHead):
+ r"""An anchor-based head used in `RetinaNet
+ `_.
+
+ The head contains two subnetworks. The first classifies anchor boxes and
+ the second regresses deltas for the anchors.
+
+ Example:
+ >>> import torch
+ >>> self = RetinaHead(11, 7)
+ >>> x = torch.rand(1, 7, 32, 32)
+ >>> cls_score, bbox_pred = self.forward_single(x)
+ >>> # Each anchor predicts a score for each class except background
+ >>> cls_per_anchor = cls_score.shape[1] / self.num_anchors
+ >>> box_per_anchor = bbox_pred.shape[1] / self.num_anchors
+ >>> assert cls_per_anchor == (self.num_classes)
+ >>> assert box_per_anchor == 4
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(RetinaHead, self).__init__(
+ num_classes,
+ in_channels,
+ anchor_generator=anchor_generator,
+ **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = nn.Conv2d(
+ self.feat_channels, self.num_anchors * 4, 3, padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.retina_cls, std=0.01, bias=bias_cls)
+ normal_init(self.retina_reg, std=0.01)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+ cls_score (Tensor): Cls scores for a single scale level
+ the channels number is num_anchors * num_classes.
+ bbox_pred (Tensor): Box energies / deltas for a single scale
+ level, the channels number is num_anchors * 4.
+ """
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_pred = self.retina_reg(reg_feat)
+ return cls_score, bbox_pred
diff --git a/mmdet/models/dense_heads/retina_sepbn_head.py b/mmdet/models/dense_heads/retina_sepbn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b8ce7f0104b90af4b128e0f245473a1c0219fcd
--- /dev/null
+++ b/mmdet/models/dense_heads/retina_sepbn_head.py
@@ -0,0 +1,113 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class RetinaSepBNHead(AnchorHead):
+ """"RetinaHead with separate BN.
+
+ In RetinaHead, conv/norm layers are shared across different FPN levels,
+ while in RetinaSepBNHead, conv layers are shared across different FPN
+ levels, but BN layers are separated.
+ """
+
+ def __init__(self,
+ num_classes,
+ num_ins,
+ in_channels,
+ stacked_convs=4,
+ conv_cfg=None,
+ norm_cfg=None,
+ **kwargs):
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.num_ins = num_ins
+ super(RetinaSepBNHead, self).__init__(num_classes, in_channels,
+ **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.num_ins):
+ cls_convs = nn.ModuleList()
+ reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.cls_convs.append(cls_convs)
+ self.reg_convs.append(reg_convs)
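+        # Share the conv weights (but not the BN statistics/parameters)
+        # across FPN levels by aliasing the conv submodules of levels
+        # 1..num_ins-1 to those of level 0.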
+ for i in range(self.stacked_convs):
+ for j in range(1, self.num_ins):
+ self.cls_convs[j][i].conv = self.cls_convs[0][i].conv
+ self.reg_convs[j][i].conv = self.reg_convs[0][i].conv
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.retina_reg = nn.Conv2d(
+ self.feat_channels, self.num_anchors * 4, 3, padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs[0]:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs[0]:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.retina_cls, std=0.01, bias=bias_cls)
+ normal_init(self.retina_reg, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+            tuple: Usually a tuple of classification scores and bbox predictions.
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * 4.
+ """
+ cls_scores = []
+ bbox_preds = []
+ for i, x in enumerate(feats):
+ cls_feat = feats[i]
+ reg_feat = feats[i]
+ for cls_conv in self.cls_convs[i]:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs[i]:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_pred = self.retina_reg(reg_feat)
+ cls_scores.append(cls_score)
+ bbox_preds.append(bbox_pred)
+ return cls_scores, bbox_preds
diff --git a/mmdet/models/dense_heads/rpn_head.py b/mmdet/models/dense_heads/rpn_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a888cb8c188ca6fe63045b6230266553fbe8c996
--- /dev/null
+++ b/mmdet/models/dense_heads/rpn_head.py
@@ -0,0 +1,236 @@
+import copy
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv import ConfigDict
+from mmcv.cnn import normal_init
+from mmcv.ops import batched_nms
+
+from ..builder import HEADS
+from .anchor_head import AnchorHead
+from .rpn_test_mixin import RPNTestMixin
+
+
+@HEADS.register_module()
+class RPNHead(RPNTestMixin, AnchorHead):
+ """RPN head.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ """ # noqa: W605
+
+ def __init__(self, in_channels, **kwargs):
+ super(RPNHead, self).__init__(1, in_channels, **kwargs)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.rpn_conv = nn.Conv2d(
+ self.in_channels, self.feat_channels, 3, padding=1)
+ self.rpn_cls = nn.Conv2d(self.feat_channels,
+ self.num_anchors * self.cls_out_channels, 1)
+ self.rpn_reg = nn.Conv2d(self.feat_channels, self.num_anchors * 4, 1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ normal_init(self.rpn_conv, std=0.01)
+ normal_init(self.rpn_cls, std=0.01)
+ normal_init(self.rpn_reg, std=0.01)
+
+ def forward_single(self, x):
+ """Forward feature map of a single scale level."""
+ x = self.rpn_conv(x)
+ x = F.relu(x, inplace=True)
+ rpn_cls_score = self.rpn_cls(x)
+ rpn_bbox_pred = self.rpn_reg(x)
+ return rpn_cls_score, rpn_bbox_pred
+
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ losses = super(RPNHead, self).loss(
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ None,
+ img_metas,
+ gt_bboxes_ignore=gt_bboxes_ignore)
+ return dict(
+ loss_rpn_cls=losses['loss_cls'], loss_rpn_bbox=losses['loss_bbox'])
+
+ def _get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_anchors,
+ img_shapes,
+ scale_factors,
+ cfg,
+ rescale=False):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W).
+ mlvl_anchors (list[Tensor]): Box reference for each scale level
+ with shape (num_total_anchors, 4).
+ img_shapes (list[tuple[int]]): Shape of the input image,
+ (height, width, 3).
+            scale_factors (list[ndarray]): Scale factors of the image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a
+                2-tuple. The first item is an (n, 5) tensor, where the first
+                4 columns are bounding box positions (tl_x, tl_y, br_x, br_y)
+                and the 5-th column is a score between 0 and 1. The second
+                item is an (n,) tensor where each item is the predicted class
+                label of the corresponding box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ cfg = copy.deepcopy(cfg)
+ # bboxes from different level should be independent during NMS,
+ # level_ids are used as labels for batched NMS to separate them
+ level_ids = []
+ mlvl_scores = []
+ mlvl_bbox_preds = []
+ mlvl_valid_anchors = []
+ batch_size = cls_scores[0].shape[0]
+ nms_pre_tensor = torch.tensor(
+ cfg.nms_pre, device=cls_scores[0].device, dtype=torch.long)
+ for idx in range(len(cls_scores)):
+ rpn_cls_score = cls_scores[idx]
+ rpn_bbox_pred = bbox_preds[idx]
+ assert rpn_cls_score.size()[-2:] == rpn_bbox_pred.size()[-2:]
+ rpn_cls_score = rpn_cls_score.permute(0, 2, 3, 1)
+ if self.use_sigmoid_cls:
+ rpn_cls_score = rpn_cls_score.reshape(batch_size, -1)
+ scores = rpn_cls_score.sigmoid()
+ else:
+ rpn_cls_score = rpn_cls_score.reshape(batch_size, -1, 2)
+                # We set FG labels to [0, num_class-1] and the BG label to
+                # num_class in the RPN head since mmdet v2.5, which unifies
+                # it with the other heads (consistent since mmdet v2.0).
+                # From mmdet v2.0 to v2.4 the RPN head kept the BG label as
+                # 0 and the FG label as 1.
+ scores = rpn_cls_score.softmax(-1)[..., 0]
+ rpn_bbox_pred = rpn_bbox_pred.permute(0, 2, 3, 1).reshape(
+ batch_size, -1, 4)
+ anchors = mlvl_anchors[idx]
+ anchors = anchors.expand_as(rpn_bbox_pred)
+ if nms_pre_tensor > 0:
+ # sort is faster than topk
+ # _, topk_inds = scores.topk(cfg.nms_pre)
+ # keep topk op for dynamic k in onnx model
+ if torch.onnx.is_in_onnx_export():
+ # sort op will be converted to TopK in onnx
+ # and k<=3480 in TensorRT
+ scores_shape = torch._shape_as_tensor(scores)
+ nms_pre = torch.where(scores_shape[1] < nms_pre_tensor,
+ scores_shape[1], nms_pre_tensor)
+ _, topk_inds = scores.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds)
+ scores = scores[batch_inds, topk_inds]
+ rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :]
+ anchors = anchors[batch_inds, topk_inds, :]
+
+ elif scores.shape[-1] > cfg.nms_pre:
+ ranked_scores, rank_inds = scores.sort(descending=True)
+ topk_inds = rank_inds[:, :cfg.nms_pre]
+ scores = ranked_scores[:, :cfg.nms_pre]
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds)
+ rpn_bbox_pred = rpn_bbox_pred[batch_inds, topk_inds, :]
+ anchors = anchors[batch_inds, topk_inds, :]
+
+ mlvl_scores.append(scores)
+ mlvl_bbox_preds.append(rpn_bbox_pred)
+ mlvl_valid_anchors.append(anchors)
+ level_ids.append(
+ scores.new_full((
+ batch_size,
+ scores.size(1),
+ ),
+ idx,
+ dtype=torch.long))
+
+ batch_mlvl_scores = torch.cat(mlvl_scores, dim=1)
+ batch_mlvl_anchors = torch.cat(mlvl_valid_anchors, dim=1)
+ batch_mlvl_rpn_bbox_pred = torch.cat(mlvl_bbox_preds, dim=1)
+ batch_mlvl_proposals = self.bbox_coder.decode(
+ batch_mlvl_anchors, batch_mlvl_rpn_bbox_pred, max_shape=img_shapes)
+ batch_mlvl_ids = torch.cat(level_ids, dim=1)
+
+ # deprecate arguments warning
+ if 'nms' not in cfg or 'max_num' in cfg or 'nms_thr' in cfg:
+            warnings.warn(
+                'In rpn_proposal or test_cfg, '
+                'nms_thr has been moved into a dict named nms as '
+                'iou_threshold, and max_num has been renamed max_per_img; '
+                'the original argument names and the old way of specifying '
+                'the iou_threshold of NMS will be deprecated.')
+ if 'nms' not in cfg:
+ cfg.nms = ConfigDict(dict(type='nms', iou_threshold=cfg.nms_thr))
+ if 'max_num' in cfg:
+ if 'max_per_img' in cfg:
+                assert cfg.max_num == cfg.max_per_img, f'You ' \
+                    f'set max_num and ' \
+                    f'max_per_img at the same time, but get {cfg.max_num} ' \
+                    f'and {cfg.max_per_img} respectively. ' \
+                    'Please delete max_num, which will be deprecated.'
+ else:
+ cfg.max_per_img = cfg.max_num
+ if 'nms_thr' in cfg:
+ assert cfg.nms.iou_threshold == cfg.nms_thr, f'You set' \
+ f' iou_threshold in nms and ' \
+ f'nms_thr at the same time, but get' \
+ f' {cfg.nms.iou_threshold} and {cfg.nms_thr}' \
+ f' respectively. Please delete the nms_thr ' \
+ f'which will be deprecated.'
+
+ result_list = []
+ for (mlvl_proposals, mlvl_scores,
+ mlvl_ids) in zip(batch_mlvl_proposals, batch_mlvl_scores,
+ batch_mlvl_ids):
+ # Skip nonzero op while exporting to ONNX
+ if cfg.min_bbox_size > 0 and (not torch.onnx.is_in_onnx_export()):
+ w = mlvl_proposals[:, 2] - mlvl_proposals[:, 0]
+ h = mlvl_proposals[:, 3] - mlvl_proposals[:, 1]
+ valid_ind = torch.nonzero(
+ (w >= cfg.min_bbox_size)
+ & (h >= cfg.min_bbox_size),
+ as_tuple=False).squeeze()
+ if valid_ind.sum().item() != len(mlvl_proposals):
+ mlvl_proposals = mlvl_proposals[valid_ind, :]
+ mlvl_scores = mlvl_scores[valid_ind]
+ mlvl_ids = mlvl_ids[valid_ind]
+
+ dets, keep = batched_nms(mlvl_proposals, mlvl_scores, mlvl_ids,
+ cfg.nms)
+ result_list.append(dets[:cfg.max_per_img])
+ return result_list
diff --git a/mmdet/models/dense_heads/rpn_test_mixin.py b/mmdet/models/dense_heads/rpn_test_mixin.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ce5c66f82595f496e6e55719c1caee75150d568
--- /dev/null
+++ b/mmdet/models/dense_heads/rpn_test_mixin.py
@@ -0,0 +1,59 @@
+import sys
+
+from mmdet.core import merge_aug_proposals
+
+if sys.version_info >= (3, 7):
+ from mmdet.utils.contextmanagers import completed
+
+
+class RPNTestMixin(object):
+ """Test methods of RPN."""
+
+ if sys.version_info >= (3, 7):
+
+ async def async_simple_test_rpn(self, x, img_metas):
+ sleep_interval = self.test_cfg.pop('async_sleep_interval', 0.025)
+ async with completed(
+ __name__, 'rpn_head_forward',
+ sleep_interval=sleep_interval):
+ rpn_outs = self(x)
+
+ proposal_list = self.get_bboxes(*rpn_outs, img_metas)
+ return proposal_list
+
+ def simple_test_rpn(self, x, img_metas):
+ """Test without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Proposals of each image.
+ """
+ rpn_outs = self(x)
+ proposal_list = self.get_bboxes(*rpn_outs, img_metas)
+ return proposal_list
+
+ def aug_test_rpn(self, feats, img_metas):
+ samples_per_gpu = len(img_metas[0])
+ aug_proposals = [[] for _ in range(samples_per_gpu)]
+ for x, img_meta in zip(feats, img_metas):
+ proposal_list = self.simple_test_rpn(x, img_meta)
+ for i, proposals in enumerate(proposal_list):
+ aug_proposals[i].append(proposals)
+ # reorganize the order of 'img_metas' to match the dimensions
+ # of 'aug_proposals'
+ aug_img_metas = []
+ for i in range(samples_per_gpu):
+ aug_img_meta = []
+ for j in range(len(img_metas)):
+ aug_img_meta.append(img_metas[j][i])
+ aug_img_metas.append(aug_img_meta)
+ # after merging, proposals will be rescaled to the original image size
+ merged_proposals = [
+ merge_aug_proposals(proposals, aug_img_meta, self.test_cfg)
+ for proposals, aug_img_meta in zip(aug_proposals, aug_img_metas)
+ ]
+ return merged_proposals
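+
+# Illustration (comment only): aug_test_rpn receives img_metas indexed as
+# img_metas[aug][img] and transposes it to aug_img_metas[img][aug] before
+# merging; e.g. with 2 augmentations and 2 images per GPU:
+#   img_metas     = [[m0_a0, m1_a0], [m0_a1, m1_a1]]
+#   aug_img_metas = [[m0_a0, m0_a1], [m1_a0, m1_a1]]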
diff --git a/mmdet/models/dense_heads/sabl_retina_head.py b/mmdet/models/dense_heads/sabl_retina_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4211622cb8b4fe807230a89bcaab8f4f1681bfc0
--- /dev/null
+++ b/mmdet/models/dense_heads/sabl_retina_head.py
@@ -0,0 +1,621 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, bias_init_with_prob, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (build_anchor_generator, build_assigner,
+ build_bbox_coder, build_sampler, images_to_levels,
+ multi_apply, multiclass_nms, unmap)
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .guided_anchor_head import GuidedAnchorHead
+
+
+@HEADS.register_module()
+class SABLRetinaHead(BaseDenseHead):
+ """Side-Aware Boundary Localization (SABL) for RetinaNet.
+
+ The anchor generation, assigning and sampling in SABLRetinaHead
+ are the same as GuidedAnchorHead for guided anchoring.
+
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+ Args:
+ num_classes (int): Number of classes.
+ in_channels (int): Number of channels in the input feature map.
+ stacked_convs (int): Number of Convs for classification \
+ and regression branches. Defaults to 4.
+ feat_channels (int): Number of hidden channels. \
+ Defaults to 256.
+ approx_anchor_generator (dict): Config dict for approx generator.
+ square_anchor_generator (dict): Config dict for square generator.
+ conv_cfg (dict): Config dict for ConvModule. Defaults to None.
+ norm_cfg (dict): Config dict for Norm Layer. Defaults to None.
+ bbox_coder (dict): Config dict for bbox coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ train_cfg (dict): Training config of SABLRetinaHead.
+ test_cfg (dict): Testing config of SABLRetinaHead.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox_cls (dict): Config of classification loss for bbox branch.
+ loss_bbox_reg (dict): Config of regression loss for bbox branch.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ stacked_convs=4,
+ feat_channels=256,
+ approx_anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=4,
+ scales_per_octave=3,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ square_anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ scales=[4],
+ strides=[8, 16, 32, 64, 128]),
+ conv_cfg=None,
+ norm_cfg=None,
+ bbox_coder=dict(
+ type='BucketingBBoxCoder',
+ num_buckets=14,
+ scale_factor=3.0),
+ reg_decoded_bbox=False,
+ train_cfg=None,
+ test_cfg=None,
+ loss_cls=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ loss_bbox_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.5),
+ loss_bbox_reg=dict(
+ type='SmoothL1Loss', beta=1.0 / 9.0, loss_weight=1.5)):
+ super(SABLRetinaHead, self).__init__()
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.feat_channels = feat_channels
+ self.num_buckets = bbox_coder['num_buckets']
+ self.side_num = int(np.ceil(self.num_buckets / 2))
+
+ assert (approx_anchor_generator['octave_base_scale'] ==
+ square_anchor_generator['scales'][0])
+ assert (approx_anchor_generator['strides'] ==
+ square_anchor_generator['strides'])
+
+ self.approx_anchor_generator = build_anchor_generator(
+ approx_anchor_generator)
+ self.square_anchor_generator = build_anchor_generator(
+ square_anchor_generator)
+ self.approxs_per_octave = (
+ self.approx_anchor_generator.num_base_anchors[0])
+
+ # one anchor per location
+ self.num_anchors = 1
+ self.stacked_convs = stacked_convs
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.reg_decoded_bbox = reg_decoded_bbox
+
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ self.sampling = loss_cls['type'] not in [
+ 'FocalLoss', 'GHMC', 'QualityFocalLoss'
+ ]
+ if self.use_sigmoid_cls:
+ self.cls_out_channels = num_classes
+ else:
+ self.cls_out_channels = num_classes + 1
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox_cls = build_loss(loss_bbox_cls)
+ self.loss_bbox_reg = build_loss(loss_bbox_reg)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # use PseudoSampler when sampling is False
+ if self.sampling and hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.fp16_enabled = False
+ self._init_layers()
+
+ def _init_layers(self):
+ self.relu = nn.ReLU(inplace=True)
+ self.cls_convs = nn.ModuleList()
+ self.reg_convs = nn.ModuleList()
+ for i in range(self.stacked_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.cls_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.reg_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.retina_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+ self.retina_bbox_reg = nn.Conv2d(
+ self.feat_channels, self.side_num * 4, 3, padding=1)
+ self.retina_bbox_cls = nn.Conv2d(
+ self.feat_channels, self.side_num * 4, 3, padding=1)
+
+ def init_weights(self):
+ for m in self.cls_convs:
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ normal_init(m.conv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.retina_cls, std=0.01, bias=bias_cls)
+ normal_init(self.retina_bbox_reg, std=0.01)
+ normal_init(self.retina_bbox_cls, std=0.01)
+
+ def forward_single(self, x):
+ cls_feat = x
+ reg_feat = x
+ for cls_conv in self.cls_convs:
+ cls_feat = cls_conv(cls_feat)
+ for reg_conv in self.reg_convs:
+ reg_feat = reg_conv(reg_feat)
+ cls_score = self.retina_cls(cls_feat)
+ bbox_cls_pred = self.retina_bbox_cls(reg_feat)
+ bbox_reg_pred = self.retina_bbox_reg(reg_feat)
+ bbox_pred = (bbox_cls_pred, bbox_reg_pred)
+ return cls_score, bbox_pred
+
+ def forward(self, feats):
+ return multi_apply(self.forward_single, feats)
+
+ def get_anchors(self, featmap_sizes, img_metas, device='cuda'):
+ """Get squares according to feature map sizes and guided anchors.
+
+ Args:
+ featmap_sizes (list[tuple]): Multi-level feature map sizes.
+ img_metas (list[dict]): Image meta info.
+ device (torch.device | str): device for returned tensors
+
+ Returns:
+ tuple: square approxs of each image
+ """
+ num_imgs = len(img_metas)
+
+ # since feature map sizes of all images are the same, we only compute
+ # squares for one time
+ multi_level_squares = self.square_anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+ squares_list = [multi_level_squares for _ in range(num_imgs)]
+
+ return squares_list
+
+ def get_target(self,
+ approx_list,
+ inside_flag_list,
+ square_list,
+ gt_bboxes_list,
+ img_metas,
+ gt_bboxes_ignore_list=None,
+ gt_labels_list=None,
+ label_channels=None,
+ sampling=True,
+ unmap_outputs=True):
+ """Compute bucketing targets.
+ Args:
+ approx_list (list[list]): Multi level approxs of each image.
+ inside_flag_list (list[list]): Multi level inside flags of each
+ image.
+ square_list (list[list]): Multi level squares of each image.
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ img_metas (list[dict]): Meta info of each image.
+ gt_bboxes_ignore_list (list[Tensor]): ignore list of gt bboxes.
+            gt_labels_list (list[Tensor]): Gt labels of each image.
+ label_channels (int): Channel of label.
+ sampling (bool): Sample Anchors or not.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+ tuple: Returns a tuple containing learning targets.
+
+ - labels_list (list[Tensor]): Labels of each level.
+ - label_weights_list (list[Tensor]): Label weights of each \
+ level.
+ - bbox_cls_targets_list (list[Tensor]): BBox cls targets of \
+ each level.
+ - bbox_cls_weights_list (list[Tensor]): BBox cls weights of \
+ each level.
+ - bbox_reg_targets_list (list[Tensor]): BBox reg targets of \
+ each level.
+ - bbox_reg_weights_list (list[Tensor]): BBox reg weights of \
+ each level.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+ """
+ num_imgs = len(img_metas)
+ assert len(approx_list) == len(inside_flag_list) == len(
+ square_list) == num_imgs
+ # anchor number of multi levels
+ num_level_squares = [squares.size(0) for squares in square_list[0]]
+ # concat all level anchors and flags to a single tensor
+ inside_flag_flat_list = []
+ approx_flat_list = []
+ square_flat_list = []
+ for i in range(num_imgs):
+ assert len(square_list[i]) == len(inside_flag_list[i])
+ inside_flag_flat_list.append(torch.cat(inside_flag_list[i]))
+ approx_flat_list.append(torch.cat(approx_list[i]))
+ square_flat_list.append(torch.cat(square_list[i]))
+
+ # compute targets for each image
+ if gt_bboxes_ignore_list is None:
+ gt_bboxes_ignore_list = [None for _ in range(num_imgs)]
+ if gt_labels_list is None:
+ gt_labels_list = [None for _ in range(num_imgs)]
+ (all_labels, all_label_weights, all_bbox_cls_targets,
+ all_bbox_cls_weights, all_bbox_reg_targets, all_bbox_reg_weights,
+ pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single,
+ approx_flat_list,
+ inside_flag_flat_list,
+ square_flat_list,
+ gt_bboxes_list,
+ gt_bboxes_ignore_list,
+ gt_labels_list,
+ img_metas,
+ label_channels=label_channels,
+ sampling=sampling,
+ unmap_outputs=unmap_outputs)
+ # no valid anchors
+ if any([labels is None for labels in all_labels]):
+ return None
+ # sampled anchors of all images
+ num_total_pos = sum([max(inds.numel(), 1) for inds in pos_inds_list])
+ num_total_neg = sum([max(inds.numel(), 1) for inds in neg_inds_list])
+ # split targets to a list w.r.t. multiple levels
+ labels_list = images_to_levels(all_labels, num_level_squares)
+ label_weights_list = images_to_levels(all_label_weights,
+ num_level_squares)
+ bbox_cls_targets_list = images_to_levels(all_bbox_cls_targets,
+ num_level_squares)
+ bbox_cls_weights_list = images_to_levels(all_bbox_cls_weights,
+ num_level_squares)
+ bbox_reg_targets_list = images_to_levels(all_bbox_reg_targets,
+ num_level_squares)
+ bbox_reg_weights_list = images_to_levels(all_bbox_reg_weights,
+ num_level_squares)
+ return (labels_list, label_weights_list, bbox_cls_targets_list,
+ bbox_cls_weights_list, bbox_reg_targets_list,
+ bbox_reg_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self,
+ flat_approxs,
+ inside_flags,
+ flat_squares,
+ gt_bboxes,
+ gt_bboxes_ignore,
+ gt_labels,
+ img_meta,
+ label_channels=None,
+ sampling=True,
+ unmap_outputs=True):
+ """Compute regression and classification targets for anchors in a
+ single image.
+
+ Args:
+ flat_approxs (Tensor): flat approxs of a single image,
+ shape (n, 4)
+ inside_flags (Tensor): inside flags of a single image,
+ shape (n, ).
+ flat_squares (Tensor): flat squares of a single image,
+ shape (approxs_per_octave * n, 4)
+ gt_bboxes (Tensor): Ground truth bboxes of a single image, \
+ shape (num_gts, 4).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ img_meta (dict): Meta info of the image.
+ label_channels (int): Channel of label.
+ sampling (bool): Sample Anchors or not.
+ unmap_outputs (bool): unmap outputs or not.
+
+ Returns:
+ tuple:
+
+ - labels_list (Tensor): Labels in a single image
+ - label_weights (Tensor): Label weights in a single image
+ - bbox_cls_targets (Tensor): BBox cls targets in a single image
+ - bbox_cls_weights (Tensor): BBox cls weights in a single image
+ - bbox_reg_targets (Tensor): BBox reg targets in a single image
+ - bbox_reg_weights (Tensor): BBox reg weights in a single image
+ - num_total_pos (int): Number of positive samples \
+ in a single image
+ - num_total_neg (int): Number of negative samples \
+ in a single image
+ """
+ if not inside_flags.any():
+ return (None, ) * 8
+ # assign gt and sample anchors
+ expand_inside_flags = inside_flags[:, None].expand(
+ -1, self.approxs_per_octave).reshape(-1)
+ approxs = flat_approxs[expand_inside_flags, :]
+ squares = flat_squares[inside_flags, :]
+
+ assign_result = self.assigner.assign(approxs, squares,
+ self.approxs_per_octave,
+ gt_bboxes, gt_bboxes_ignore)
+ sampling_result = self.sampler.sample(assign_result, squares,
+ gt_bboxes)
+
+ num_valid_squares = squares.shape[0]
+ bbox_cls_targets = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ bbox_cls_weights = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ bbox_reg_targets = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ bbox_reg_weights = squares.new_zeros(
+ (num_valid_squares, self.side_num * 4))
+ labels = squares.new_full((num_valid_squares, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = squares.new_zeros(num_valid_squares, dtype=torch.float)
+
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+ if len(pos_inds) > 0:
+ (pos_bbox_reg_targets, pos_bbox_reg_weights, pos_bbox_cls_targets,
+ pos_bbox_cls_weights) = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes)
+
+ bbox_cls_targets[pos_inds, :] = pos_bbox_cls_targets
+ bbox_reg_targets[pos_inds, :] = pos_bbox_reg_targets
+ bbox_cls_weights[pos_inds, :] = pos_bbox_cls_weights
+ bbox_reg_weights[pos_inds, :] = pos_bbox_reg_weights
+ if gt_labels is None:
+ # Only rpn gives gt_labels as None
+ # Foreground is the first class
+ labels[pos_inds] = 0
+ else:
+ labels[pos_inds] = gt_labels[
+ sampling_result.pos_assigned_gt_inds]
+ if self.train_cfg.pos_weight <= 0:
+ label_weights[pos_inds] = 1.0
+ else:
+ label_weights[pos_inds] = self.train_cfg.pos_weight
+ if len(neg_inds) > 0:
+ label_weights[neg_inds] = 1.0
+
+ # map up to original set of anchors
+ if unmap_outputs:
+ num_total_anchors = flat_squares.size(0)
+ labels = unmap(
+ labels, num_total_anchors, inside_flags, fill=self.num_classes)
+ label_weights = unmap(label_weights, num_total_anchors,
+ inside_flags)
+ bbox_cls_targets = unmap(bbox_cls_targets, num_total_anchors,
+ inside_flags)
+ bbox_cls_weights = unmap(bbox_cls_weights, num_total_anchors,
+ inside_flags)
+ bbox_reg_targets = unmap(bbox_reg_targets, num_total_anchors,
+ inside_flags)
+ bbox_reg_weights = unmap(bbox_reg_weights, num_total_anchors,
+ inside_flags)
+ return (labels, label_weights, bbox_cls_targets, bbox_cls_weights,
+ bbox_reg_targets, bbox_reg_weights, pos_inds, neg_inds)
+
+ def loss_single(self, cls_score, bbox_pred, labels, label_weights,
+ bbox_cls_targets, bbox_cls_weights, bbox_reg_targets,
+ bbox_reg_weights, num_total_samples):
+ # classification loss
+ labels = labels.reshape(-1)
+ label_weights = label_weights.reshape(-1)
+ cls_score = cls_score.permute(0, 2, 3,
+ 1).reshape(-1, self.cls_out_channels)
+ loss_cls = self.loss_cls(
+ cls_score, labels, label_weights, avg_factor=num_total_samples)
+ # regression loss
+ bbox_cls_targets = bbox_cls_targets.reshape(-1, self.side_num * 4)
+ bbox_cls_weights = bbox_cls_weights.reshape(-1, self.side_num * 4)
+ bbox_reg_targets = bbox_reg_targets.reshape(-1, self.side_num * 4)
+ bbox_reg_weights = bbox_reg_weights.reshape(-1, self.side_num * 4)
+ (bbox_cls_pred, bbox_reg_pred) = bbox_pred
+ bbox_cls_pred = bbox_cls_pred.permute(0, 2, 3, 1).reshape(
+ -1, self.side_num * 4)
+ bbox_reg_pred = bbox_reg_pred.permute(0, 2, 3, 1).reshape(
+ -1, self.side_num * 4)
+ loss_bbox_cls = self.loss_bbox_cls(
+ bbox_cls_pred,
+ bbox_cls_targets.long(),
+ bbox_cls_weights,
+ avg_factor=num_total_samples * 4 * self.side_num)
+ loss_bbox_reg = self.loss_bbox_reg(
+ bbox_reg_pred,
+ bbox_reg_targets,
+ bbox_reg_weights,
+ avg_factor=num_total_samples * 4 * self.bbox_coder.offset_topk)
+ return loss_cls, loss_bbox_cls, loss_bbox_reg
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.approx_anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ # get sampled approxes
+ approxs_list, inside_flag_list = GuidedAnchorHead.get_sampled_approxs(
+ self, featmap_sizes, img_metas, device=device)
+
+ square_list = self.get_anchors(featmap_sizes, img_metas, device=device)
+
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = self.get_target(
+ approxs_list,
+ inside_flag_list,
+ square_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ sampling=self.sampling)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_cls_targets_list,
+ bbox_cls_weights_list, bbox_reg_targets_list, bbox_reg_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ num_total_samples = (
+ num_total_pos + num_total_neg if self.sampling else num_total_pos)
+ losses_cls, losses_bbox_cls, losses_bbox_reg = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ labels_list,
+ label_weights_list,
+ bbox_cls_targets_list,
+ bbox_cls_weights_list,
+ bbox_reg_targets_list,
+ bbox_reg_weights_list,
+ num_total_samples=num_total_samples)
+ return dict(
+ loss_cls=losses_cls,
+ loss_bbox_cls=losses_bbox_cls,
+ loss_bbox_reg=losses_bbox_reg)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ img_metas,
+ cfg=None,
+ rescale=False):
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+
+ device = cls_scores[0].device
+ mlvl_anchors = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_cls_pred_list = [
+ bbox_preds[i][0][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_reg_pred_list = [
+ bbox_preds[i][1][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self.get_bboxes_single(cls_score_list,
+ bbox_cls_pred_list,
+ bbox_reg_pred_list,
+ mlvl_anchors[img_id], img_shape,
+ scale_factor, cfg, rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def get_bboxes_single(self,
+ cls_scores,
+ bbox_cls_preds,
+ bbox_reg_preds,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+ cfg = self.test_cfg if cfg is None else cfg
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_confids = []
+ assert len(cls_scores) == len(bbox_cls_preds) == len(
+ bbox_reg_preds) == len(mlvl_anchors)
+ for cls_score, bbox_cls_pred, bbox_reg_pred, anchors in zip(
+ cls_scores, bbox_cls_preds, bbox_reg_preds, mlvl_anchors):
+            assert cls_score.size()[-2:] == bbox_cls_pred.size(
+            )[-2:] == bbox_reg_pred.size()[-2:]
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_cls_pred = bbox_cls_pred.permute(1, 2, 0).reshape(
+ -1, self.side_num * 4)
+ bbox_reg_pred = bbox_reg_pred.permute(1, 2, 0).reshape(
+ -1, self.side_num * 4)
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[0] > nms_pre:
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_cls_pred = bbox_cls_pred[topk_inds, :]
+ bbox_reg_pred = bbox_reg_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ bbox_preds = [
+ bbox_cls_pred.contiguous(),
+ bbox_reg_pred.contiguous()
+ ]
+ bboxes, confids = self.bbox_coder.decode(
+ anchors.contiguous(), bbox_preds, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_confids.append(confids)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ mlvl_confids = torch.cat(mlvl_confids)
+ if self.use_sigmoid_cls:
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ det_bboxes, det_labels = multiclass_nms(
+ mlvl_bboxes,
+ mlvl_scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=mlvl_confids)
+ return det_bboxes, det_labels
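+
+# Minimal forward sketch (illustrative; the head arguments and feature shapes
+# below are assumptions, not part of this module):
+# >>> self = SABLRetinaHead(num_classes=11, in_channels=7)
+# >>> feats = [torch.rand(1, 7, s, s) for s in [64, 32, 16, 8, 4]]
+# >>> cls_scores, bbox_preds = self.forward(feats)
+# >>> # each bbox_preds[i] is a (bbox_cls_pred, bbox_reg_pred) pair, both with
+# >>> # side_num * 4 channels from retina_bbox_cls / retina_bbox_reg above.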
diff --git a/mmdet/models/dense_heads/ssd_head.py b/mmdet/models/dense_heads/ssd_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..145622b64e3f0b3f7f518fc61a2a01348ebfa4f3
--- /dev/null
+++ b/mmdet/models/dense_heads/ssd_head.py
@@ -0,0 +1,265 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import xavier_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (build_anchor_generator, build_assigner,
+ build_bbox_coder, build_sampler, multi_apply)
+from ..builder import HEADS
+from ..losses import smooth_l1_loss
+from .anchor_head import AnchorHead
+
+
+# TODO: add loss evaluator for SSD
+@HEADS.register_module()
+class SSDHead(AnchorHead):
+ """SSD head used in https://arxiv.org/abs/1512.02325.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ reg_decoded_bbox (bool): If true, the regression loss would be
+ applied directly on decoded bounding boxes, converting both
+ the predicted boxes and regression targets to absolute
+ coordinates format. Default False. It should be `True` when
+ using `IoULoss`, `GIoULoss`, or `DIoULoss` in the bbox head.
+ train_cfg (dict): Training config of anchor head.
+ test_cfg (dict): Testing config of anchor head.
+ """ # noqa: W605
+
+ def __init__(self,
+ num_classes=80,
+ in_channels=(512, 1024, 512, 256, 256, 256),
+ anchor_generator=dict(
+ type='SSDAnchorGenerator',
+ scale_major=False,
+ input_size=300,
+ strides=[8, 16, 32, 64, 100, 300],
+ ratios=([2], [2, 3], [2, 3], [2, 3], [2], [2]),
+ basesize_ratio_range=(0.1, 0.9)),
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=[.0, .0, .0, .0],
+ target_stds=[1.0, 1.0, 1.0, 1.0],
+ ),
+ reg_decoded_bbox=False,
+ train_cfg=None,
+ test_cfg=None):
+ super(AnchorHead, self).__init__()
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.cls_out_channels = num_classes + 1 # add background class
+ self.anchor_generator = build_anchor_generator(anchor_generator)
+ num_anchors = self.anchor_generator.num_base_anchors
+
+ reg_convs = []
+ cls_convs = []
+ for i in range(len(in_channels)):
+ reg_convs.append(
+ nn.Conv2d(
+ in_channels[i],
+ num_anchors[i] * 4,
+ kernel_size=3,
+ padding=1))
+ cls_convs.append(
+ nn.Conv2d(
+ in_channels[i],
+ num_anchors[i] * (num_classes + 1),
+ kernel_size=3,
+ padding=1))
+ self.reg_convs = nn.ModuleList(reg_convs)
+ self.cls_convs = nn.ModuleList(cls_convs)
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.reg_decoded_bbox = reg_decoded_bbox
+ self.use_sigmoid_cls = False
+ self.cls_focal_loss = False
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+        # set sampling=False for anchor_target
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ # SSD sampling=False so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.fp16_enabled = False
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform', bias=0)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple:
+ cls_scores (list[Tensor]): Classification scores for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * num_classes.
+ bbox_preds (list[Tensor]): Box energies / deltas for all scale
+ levels, each is a 4D-tensor, the channels number is
+ num_anchors * 4.
+ """
+ cls_scores = []
+ bbox_preds = []
+ for feat, reg_conv, cls_conv in zip(feats, self.reg_convs,
+ self.cls_convs):
+ cls_scores.append(cls_conv(feat))
+ bbox_preds.append(reg_conv(feat))
+ return cls_scores, bbox_preds
+
+ def loss_single(self, cls_score, bbox_pred, anchor, labels, label_weights,
+ bbox_targets, bbox_weights, num_total_samples):
+ """Compute loss of a single image.
+
+ Args:
+            cls_score (Tensor): Box scores for each image
+ Has shape (num_total_anchors, num_classes).
+ bbox_pred (Tensor): Box energies / deltas for each image
+ level with shape (num_total_anchors, 4).
+            anchor (Tensor): Box reference for each scale level with shape
+ (num_total_anchors, 4).
+ labels (Tensor): Labels of each anchors with shape
+ (num_total_anchors,).
+ label_weights (Tensor): Label weights of each anchor with shape
+ (num_total_anchors,)
+            bbox_targets (Tensor): BBox regression targets of each anchor with
+ shape (num_total_anchors, 4).
+ bbox_weights (Tensor): BBox regression loss weights of each anchor
+ with shape (num_total_anchors, 4).
+ num_total_samples (int): If sampling, num total samples equal to
+ the number of total anchors; Otherwise, it is the number of
+ positive anchors.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+
+ loss_cls_all = F.cross_entropy(
+ cls_score, labels, reduction='none') * label_weights
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_inds = ((labels >= 0) &
+ (labels < self.num_classes)).nonzero().reshape(-1)
+ neg_inds = (labels == self.num_classes).nonzero().view(-1)
+
+ num_pos_samples = pos_inds.size(0)
+ num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
+ if num_neg_samples > neg_inds.size(0):
+ num_neg_samples = neg_inds.size(0)
+ topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+ loss_cls_pos = loss_cls_all[pos_inds].sum()
+ loss_cls_neg = topk_loss_cls_neg.sum()
+ loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
+
+ if self.reg_decoded_bbox:
+            # When the regression loss (e.g. `IoULoss`, `GIoULoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ bbox_pred = self.bbox_coder.decode(anchor, bbox_pred)
+
+ loss_bbox = smooth_l1_loss(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ beta=self.train_cfg.smoothl1_beta,
+ avg_factor=num_total_samples)
+ return loss_cls[None], loss_bbox
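+
+    # Worked example of the hard negative mining above (illustrative): with
+    # train_cfg.neg_pos_ratio = 3 and 10 positive anchors in an image, at most
+    # the 30 negatives with the highest classification loss are kept; loss_cls
+    # then sums the positive and selected hard-negative losses and divides by
+    # num_total_samples.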
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute losses of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=1,
+ unmap_outputs=False)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+
+ num_images = len(img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ # check NaN and Inf
+ assert torch.isfinite(all_cls_scores).all().item(), \
+ 'classification scores become infinite or NaN!'
+ assert torch.isfinite(all_bbox_preds).all().item(), \
+            'bbox predictions become infinite or NaN!'
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ num_total_samples=num_total_pos)
+ return dict(loss_cls=losses_cls, loss_bbox=losses_bbox)
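+
+# Minimal forward sketch (illustrative; the SSD300-style feature map sizes are
+# an assumption, not part of this module):
+# >>> self = SSDHead()
+# >>> sizes = [38, 19, 10, 5, 3, 1]
+# >>> feats = [torch.rand(1, c, s, s)
+# ...          for c, s in zip(self.in_channels, sizes)]
+# >>> cls_scores, bbox_preds = self.forward(feats)
+# >>> # cls_scores[i] has num_anchors[i] * (num_classes + 1) channels and
+# >>> # bbox_preds[i] has num_anchors[i] * 4 channels.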
diff --git a/mmdet/models/dense_heads/transformer_head.py b/mmdet/models/dense_heads/transformer_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..820fd069fcca295f6102f0d27366158a8c640249
--- /dev/null
+++ b/mmdet/models/dense_heads/transformer_head.py
@@ -0,0 +1,654 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, Linear, build_activation_layer
+from mmcv.runner import force_fp32
+
+from mmdet.core import (bbox_cxcywh_to_xyxy, bbox_xyxy_to_cxcywh,
+ build_assigner, build_sampler, multi_apply,
+ reduce_mean)
+from mmdet.models.utils import (FFN, build_positional_encoding,
+ build_transformer)
+from ..builder import HEADS, build_loss
+from .anchor_free_head import AnchorFreeHead
+
+
+@HEADS.register_module()
+class TransformerHead(AnchorFreeHead):
+ """Implements the DETR transformer head.
+
+    See `paper: End-to-End Object Detection with Transformers
+    <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ num_classes (int): Number of categories excluding the background.
+ in_channels (int): Number of channels in the input feature map.
+ num_fcs (int, optional): Number of fully-connected layers used in
+ `FFN`, which is then used for the regression head. Default 2.
+ transformer (dict, optional): Config for transformer.
+ positional_encoding (dict, optional): Config for position encoding.
+ loss_cls (dict, optional): Config of the classification loss.
+ Default `CrossEntropyLoss`.
+ loss_bbox (dict, optional): Config of the regression loss.
+ Default `L1Loss`.
+ loss_iou (dict, optional): Config of the regression iou loss.
+ Default `GIoULoss`.
+        train_cfg (dict, optional): Training config of transformer head.
+ test_cfg (dict, optional): Testing config of transformer head.
+
+ Example:
+ >>> import torch
+ >>> self = TransformerHead(80, 2048)
+ >>> x = torch.rand(1, 2048, 32, 32)
+ >>> mask = torch.ones(1, 32, 32).to(x.dtype)
+ >>> mask[:, :16, :15] = 0
+ >>> all_cls_scores, all_bbox_preds = self(x, mask)
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ num_fcs=2,
+ transformer=dict(
+ type='Transformer',
+ embed_dims=256,
+ num_heads=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ feedforward_channels=2048,
+ dropout=0.1,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ num_fcs=2,
+ pre_norm=False,
+ return_intermediate_dec=True),
+ positional_encoding=dict(
+ type='SinePositionalEncoding',
+ num_feats=128,
+ normalize=True),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ bg_cls_weight=0.1,
+ use_sigmoid=False,
+ loss_weight=1.0,
+ class_weight=1.0),
+ loss_bbox=dict(type='L1Loss', loss_weight=5.0),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ train_cfg=dict(
+ assigner=dict(
+ type='HungarianAssigner',
+ cls_cost=dict(type='ClassificationCost', weight=1.),
+ reg_cost=dict(type='BBoxL1Cost', weight=5.0),
+ iou_cost=dict(
+ type='IoUCost', iou_mode='giou', weight=2.0))),
+ test_cfg=dict(max_per_img=100),
+ **kwargs):
+ # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+ # since it brings inconvenience when the initialization of
+ # `AnchorFreeHead` is called.
+ super(AnchorFreeHead, self).__init__()
+ use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ assert not use_sigmoid_cls, 'setting use_sigmoid_cls as True is ' \
+ 'not supported in DETR, since background is needed for the ' \
+ 'matching process.'
+ assert 'embed_dims' in transformer \
+ and 'num_feats' in positional_encoding
+ num_feats = positional_encoding['num_feats']
+ embed_dims = transformer['embed_dims']
+ assert num_feats * 2 == embed_dims, 'embed_dims should' \
+ f' be exactly 2 times of num_feats. Found {embed_dims}' \
+ f' and {num_feats}.'
+ assert test_cfg is not None and 'max_per_img' in test_cfg
+
+ class_weight = loss_cls.get('class_weight', None)
+ if class_weight is not None:
+ assert isinstance(class_weight, float), 'Expected ' \
+ 'class_weight to have type float. Found ' \
+ f'{type(class_weight)}.'
+            # NOTE following the official DETR repo, bg_cls_weight means
+ # relative classification weight of the no-object class.
+ bg_cls_weight = loss_cls.get('bg_cls_weight', class_weight)
+ assert isinstance(bg_cls_weight, float), 'Expected ' \
+ 'bg_cls_weight to have type float. Found ' \
+ f'{type(bg_cls_weight)}.'
+ class_weight = torch.ones(num_classes + 1) * class_weight
+            # set background class as the last index
+ class_weight[num_classes] = bg_cls_weight
+ loss_cls.update({'class_weight': class_weight})
+ if 'bg_cls_weight' in loss_cls:
+ loss_cls.pop('bg_cls_weight')
+ self.bg_cls_weight = bg_cls_weight
+
+ if train_cfg:
+ assert 'assigner' in train_cfg, 'assigner should be provided '\
+ 'when train_cfg is set.'
+ assigner = train_cfg['assigner']
+ assert loss_cls['loss_weight'] == assigner['cls_cost']['weight'], \
+ 'The classification weight for loss and matcher should be' \
+ 'exactly the same.'
+ assert loss_bbox['loss_weight'] == assigner['reg_cost'][
+ 'weight'], 'The regression L1 weight for loss and matcher ' \
+ 'should be exactly the same.'
+ assert loss_iou['loss_weight'] == assigner['iou_cost']['weight'], \
+ 'The regression iou weight for loss and matcher should be' \
+ 'exactly the same.'
+ self.assigner = build_assigner(assigner)
+ # DETR sampling=False, so use PseudoSampler
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.num_classes = num_classes
+ self.cls_out_channels = num_classes + 1
+ self.in_channels = in_channels
+ self.num_fcs = num_fcs
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.use_sigmoid_cls = use_sigmoid_cls
+ self.embed_dims = embed_dims
+ self.num_query = test_cfg['max_per_img']
+ self.fp16_enabled = False
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_iou = build_loss(loss_iou)
+ self.act_cfg = transformer.get('act_cfg',
+ dict(type='ReLU', inplace=True))
+ self.activate = build_activation_layer(self.act_cfg)
+ self.positional_encoding = build_positional_encoding(
+ positional_encoding)
+ self.transformer = build_transformer(transformer)
+ self._init_layers()
+
+ def _init_layers(self):
+ """Initialize layers of the transformer head."""
+ self.input_proj = Conv2d(
+ self.in_channels, self.embed_dims, kernel_size=1)
+ self.fc_cls = Linear(self.embed_dims, self.cls_out_channels)
+ self.reg_ffn = FFN(
+ self.embed_dims,
+ self.embed_dims,
+ self.num_fcs,
+ self.act_cfg,
+ dropout=0.0,
+ add_residual=False)
+ self.fc_reg = Linear(self.embed_dims, 4)
+ self.query_embedding = nn.Embedding(self.num_query, self.embed_dims)
+
+ def init_weights(self, distribution='uniform'):
+ """Initialize weights of the transformer head."""
+ # The initialization for transformer is important
+ self.transformer.init_weights()
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+ """load checkpoints."""
+ # NOTE here use `AnchorFreeHead` instead of `TransformerHead`,
+ # since `AnchorFreeHead._load_from_state_dict` should not be
+ # called here. Invoking the default `Module._load_from_state_dict`
+ # is enough.
+ super(AnchorFreeHead,
+ self)._load_from_state_dict(state_dict, prefix, local_metadata,
+ strict, missing_keys,
+ unexpected_keys, error_msgs)
+
+ def forward(self, feats, img_metas):
+ """Forward function.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: Outputs for all scale levels.
+
+ - all_cls_scores_list (list[Tensor]): Classification scores \
+ for each scale level. Each is a 4D-tensor with shape \
+ [nb_dec, bs, num_query, cls_out_channels]. Note \
+              `cls_out_channels` should include background.
+ - all_bbox_preds_list (list[Tensor]): Sigmoid regression \
+ outputs for each scale level. Each is a 4D-tensor with \
+ normalized coordinate format (cx, cy, w, h) and shape \
+ [nb_dec, bs, num_query, 4].
+ """
+ num_levels = len(feats)
+ img_metas_list = [img_metas for _ in range(num_levels)]
+ return multi_apply(self.forward_single, feats, img_metas_list)
+
+ def forward_single(self, x, img_metas):
+ """"Forward function for a single feature level.
+
+ Args:
+ x (Tensor): Input feature from backbone's single stage, shape
+ [bs, c, h, w].
+ img_metas (list[dict]): List of image information.
+
+ Returns:
+ all_cls_scores (Tensor): Outputs from the classification head,
+ shape [nb_dec, bs, num_query, cls_out_channels]. Note
+                cls_out_channels should include background.
+ all_bbox_preds (Tensor): Sigmoid outputs from the regression
+ head with normalized coordinate format (cx, cy, w, h).
+ Shape [nb_dec, bs, num_query, 4].
+ """
+        # construct binary masks, which are used by the transformer.
+        # NOTE following the official DETR repo, non-zero values represent
+        # ignored positions, while zero values mean valid positions.
+ batch_size = x.size(0)
+ input_img_h, input_img_w = img_metas[0]['batch_input_shape']
+ masks = x.new_ones((batch_size, input_img_h, input_img_w))
+ for img_id in range(batch_size):
+ img_h, img_w, _ = img_metas[img_id]['img_shape']
+ masks[img_id, :img_h, :img_w] = 0
+
+ x = self.input_proj(x)
+ # interpolate masks to have the same spatial shape with x
+ masks = F.interpolate(
+ masks.unsqueeze(1), size=x.shape[-2:]).to(torch.bool).squeeze(1)
+ # position encoding
+ pos_embed = self.positional_encoding(masks) # [bs, embed_dim, h, w]
+ # outs_dec: [nb_dec, bs, num_query, embed_dim]
+ outs_dec, _ = self.transformer(x, masks, self.query_embedding.weight,
+ pos_embed)
+
+ all_cls_scores = self.fc_cls(outs_dec)
+ all_bbox_preds = self.fc_reg(self.activate(
+ self.reg_ffn(outs_dec))).sigmoid()
+ return all_cls_scores, all_bbox_preds
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def loss(self,
+ all_cls_scores_list,
+ all_bbox_preds_list,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """"Loss function.
+
+ Only outputs from the last feature level are used for computing
+ losses by default.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore (list[Tensor], optional): Bounding boxes
+ which can be ignored for each image. Default None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+        # NOTE by default, only the outputs from the last feature scale
+        # are used.
+ all_cls_scores = all_cls_scores_list[-1]
+ all_bbox_preds = all_bbox_preds_list[-1]
+ assert gt_bboxes_ignore is None, \
+ 'Only supports for gt_bboxes_ignore setting to None.'
+
+ num_dec_layers = len(all_cls_scores)
+ all_gt_bboxes_list = [gt_bboxes_list for _ in range(num_dec_layers)]
+ all_gt_labels_list = [gt_labels_list for _ in range(num_dec_layers)]
+ all_gt_bboxes_ignore_list = [
+ gt_bboxes_ignore for _ in range(num_dec_layers)
+ ]
+ img_metas_list = [img_metas for _ in range(num_dec_layers)]
+
+ losses_cls, losses_bbox, losses_iou = multi_apply(
+ self.loss_single, all_cls_scores, all_bbox_preds,
+ all_gt_bboxes_list, all_gt_labels_list, img_metas_list,
+ all_gt_bboxes_ignore_list)
+
+ loss_dict = dict()
+ # loss from the last decoder layer
+ loss_dict['loss_cls'] = losses_cls[-1]
+ loss_dict['loss_bbox'] = losses_bbox[-1]
+ loss_dict['loss_iou'] = losses_iou[-1]
+ # loss from other decoder layers
+ num_dec_layer = 0
+ for loss_cls_i, loss_bbox_i, loss_iou_i in zip(losses_cls[:-1],
+ losses_bbox[:-1],
+ losses_iou[:-1]):
+ loss_dict[f'd{num_dec_layer}.loss_cls'] = loss_cls_i
+ loss_dict[f'd{num_dec_layer}.loss_bbox'] = loss_bbox_i
+ loss_dict[f'd{num_dec_layer}.loss_iou'] = loss_iou_i
+ num_dec_layer += 1
+ return loss_dict
+
+ def loss_single(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore_list=None):
+ """"Loss function for outputs from a single decoder layer of a single
+ feature level.
+
+ Args:
+ cls_scores (Tensor): Box score logits from a single decoder layer
+ for all images. Shape [bs, num_query, cls_out_channels].
+ bbox_preds (Tensor): Sigmoid outputs from a single decoder layer
+ for all images, with normalized coordinate (cx, cy, w, h) and
+ shape [bs, num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+ boxes which can be ignored for each image. Default None.
+
+ Returns:
+            tuple[Tensor]: Loss components (loss_cls, loss_bbox, loss_iou) for
+            outputs from a single decoder layer.
+ """
+ num_imgs = cls_scores.size(0)
+ cls_scores_list = [cls_scores[i] for i in range(num_imgs)]
+ bbox_preds_list = [bbox_preds[i] for i in range(num_imgs)]
+ cls_reg_targets = self.get_targets(cls_scores_list, bbox_preds_list,
+ gt_bboxes_list, gt_labels_list,
+ img_metas, gt_bboxes_ignore_list)
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg) = cls_reg_targets
+ labels = torch.cat(labels_list, 0)
+ label_weights = torch.cat(label_weights_list, 0)
+ bbox_targets = torch.cat(bbox_targets_list, 0)
+ bbox_weights = torch.cat(bbox_weights_list, 0)
+
+ # classification loss
+ cls_scores = cls_scores.reshape(-1, self.cls_out_channels)
+ # construct weighted avg_factor to match with the official DETR repo
+ cls_avg_factor = num_total_pos * 1.0 + \
+ num_total_neg * self.bg_cls_weight
+ loss_cls = self.loss_cls(
+ cls_scores, labels, label_weights, avg_factor=cls_avg_factor)
+
+        # Compute the average number of gt boxes across all GPUs, for
+        # normalization purposes
+ num_total_pos = loss_cls.new_tensor([num_total_pos])
+ num_total_pos = torch.clamp(reduce_mean(num_total_pos), min=1).item()
+
+ # construct factors used for rescale bboxes
+ factors = []
+ for img_meta, bbox_pred in zip(img_metas, bbox_preds):
+ img_h, img_w, _ = img_meta['img_shape']
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0).repeat(
+ bbox_pred.size(0), 1)
+ factors.append(factor)
+ factors = torch.cat(factors, 0)
+
+        # DETR regresses the relative position of boxes (cxcywh) in the image,
+        # thus the learning target is normalized by the image size. So here
+        # we need to re-scale them for calculating the IoU loss.
+ bbox_preds = bbox_preds.reshape(-1, 4)
+ bboxes = bbox_cxcywh_to_xyxy(bbox_preds) * factors
+ bboxes_gt = bbox_cxcywh_to_xyxy(bbox_targets) * factors
+
+        # regression IoU loss (GIoU loss by default)
+ loss_iou = self.loss_iou(
+ bboxes, bboxes_gt, bbox_weights, avg_factor=num_total_pos)
+
+ # regression L1 loss
+ loss_bbox = self.loss_bbox(
+ bbox_preds, bbox_targets, bbox_weights, avg_factor=num_total_pos)
+ return loss_cls, loss_bbox, loss_iou
+
+ def get_targets(self,
+ cls_scores_list,
+ bbox_preds_list,
+ gt_bboxes_list,
+ gt_labels_list,
+ img_metas,
+ gt_bboxes_ignore_list=None):
+ """"Compute regression and classification targets for a batch image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_scores_list (list[Tensor]): Box score logits from a single
+ decoder layer for each image with shape [num_query,
+ cls_out_channels].
+ bbox_preds_list (list[Tensor]): Sigmoid outputs from a single
+ decoder layer for each image, with normalized coordinate
+ (cx, cy, w, h) and shape [num_query, 4].
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes for each image
+ with shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels_list (list[Tensor]): Ground truth class indices for each
+ image with shape (num_gts, ).
+ img_metas (list[dict]): List of image meta information.
+ gt_bboxes_ignore_list (list[Tensor], optional): Bounding
+ boxes which can be ignored for each image. Default None.
+
+ Returns:
+ tuple: a tuple containing the following targets.
+
+ - labels_list (list[Tensor]): Labels for all images.
+ - label_weights_list (list[Tensor]): Label weights for all \
+ images.
+ - bbox_targets_list (list[Tensor]): BBox targets for all \
+ images.
+ - bbox_weights_list (list[Tensor]): BBox weights for all \
+ images.
+ - num_total_pos (int): Number of positive samples in all \
+ images.
+ - num_total_neg (int): Number of negative samples in all \
+ images.
+ """
+ assert gt_bboxes_ignore_list is None, \
+ 'Only supports for gt_bboxes_ignore setting to None.'
+ num_imgs = len(cls_scores_list)
+ gt_bboxes_ignore_list = [
+ gt_bboxes_ignore_list for _ in range(num_imgs)
+ ]
+
+ (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, pos_inds_list, neg_inds_list) = multi_apply(
+ self._get_target_single, cls_scores_list, bbox_preds_list,
+ gt_bboxes_list, gt_labels_list, img_metas, gt_bboxes_ignore_list)
+ num_total_pos = sum((inds.numel() for inds in pos_inds_list))
+ num_total_neg = sum((inds.numel() for inds in neg_inds_list))
+ return (labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg)
+
+ def _get_target_single(self,
+ cls_score,
+ bbox_pred,
+ gt_bboxes,
+ gt_labels,
+ img_meta,
+ gt_bboxes_ignore=None):
+ """"Compute regression and classification targets for one image.
+
+ Outputs from a single decoder layer of a single feature level are used.
+
+ Args:
+ cls_score (Tensor): Box score logits from a single decoder layer
+ for one image. Shape [num_query, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from a single decoder layer
+ for one image, with normalized coordinate (cx, cy, w, h) and
+ shape [num_query, 4].
+ gt_bboxes (Tensor): Ground truth bboxes for one image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (Tensor): Ground truth class indices for one image
+ with shape (num_gts, ).
+ img_meta (dict): Meta information for one image.
+ gt_bboxes_ignore (Tensor, optional): Bounding boxes
+ which can be ignored. Default None.
+
+ Returns:
+ tuple[Tensor]: a tuple containing the following for one image.
+
+ - labels (Tensor): Labels of each image.
+ - label_weights (Tensor]): Label weights of each image.
+ - bbox_targets (Tensor): BBox targets of each image.
+ - bbox_weights (Tensor): BBox weights of each image.
+ - pos_inds (Tensor): Sampled positive indices for each image.
+ - neg_inds (Tensor): Sampled negative indices for each image.
+ """
+
+ num_bboxes = bbox_pred.size(0)
+ # assigner and sampler
+ assign_result = self.assigner.assign(bbox_pred, cls_score, gt_bboxes,
+ gt_labels, img_meta,
+ gt_bboxes_ignore)
+ sampling_result = self.sampler.sample(assign_result, bbox_pred,
+ gt_bboxes)
+ pos_inds = sampling_result.pos_inds
+ neg_inds = sampling_result.neg_inds
+
+ # label targets
+ labels = gt_bboxes.new_full((num_bboxes, ),
+ self.num_classes,
+ dtype=torch.long)
+ labels[pos_inds] = gt_labels[sampling_result.pos_assigned_gt_inds]
+ label_weights = gt_bboxes.new_ones(num_bboxes)
+
+ # bbox targets
+ bbox_targets = torch.zeros_like(bbox_pred)
+ bbox_weights = torch.zeros_like(bbox_pred)
+ bbox_weights[pos_inds] = 1.0
+ img_h, img_w, _ = img_meta['img_shape']
+
+        # DETR regresses the relative position of boxes (cxcywh) in the image.
+        # Thus the learning target should be normalized by the image size, and
+        # the box format converted from the default x1y1x2y2 to cxcywh.
+ factor = bbox_pred.new_tensor([img_w, img_h, img_w,
+ img_h]).unsqueeze(0)
+ pos_gt_bboxes_normalized = sampling_result.pos_gt_bboxes / factor
+ pos_gt_bboxes_targets = bbox_xyxy_to_cxcywh(pos_gt_bboxes_normalized)
+ bbox_targets[pos_inds] = pos_gt_bboxes_targets
+ return (labels, label_weights, bbox_targets, bbox_weights, pos_inds,
+ neg_inds)
+
+    # Overridden because img_metas are needed as inputs for bbox_head.
+ def forward_train(self,
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=None,
+ proposal_cfg=None,
+ **kwargs):
+ """Forward function for training mode.
+
+ Args:
+ x (list[Tensor]): Features from backbone.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes (Tensor): Ground truth bboxes of the image,
+ shape (num_gts, 4).
+ gt_labels (Tensor): Ground truth labels of each box,
+ shape (num_gts,).
+ gt_bboxes_ignore (Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+ proposal_cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert proposal_cfg is None, '"proposal_cfg" must be None'
+ outs = self(x, img_metas)
+ if gt_labels is None:
+ loss_inputs = outs + (gt_bboxes, img_metas)
+ else:
+ loss_inputs = outs + (gt_bboxes, gt_labels, img_metas)
+ losses = self.loss(*loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+ return losses
+
+ @force_fp32(apply_to=('all_cls_scores_list', 'all_bbox_preds_list'))
+ def get_bboxes(self,
+ all_cls_scores_list,
+ all_bbox_preds_list,
+ img_metas,
+ rescale=False):
+ """Transform network outputs for a batch into bbox predictions.
+
+ Args:
+ all_cls_scores_list (list[Tensor]): Classification outputs
+ for each feature level. Each is a 4D-tensor with shape
+ [nb_dec, bs, num_query, cls_out_channels].
+ all_bbox_preds_list (list[Tensor]): Sigmoid regression
+ outputs for each feature level. Each is a 4D-tensor with
+ normalized coordinate format (cx, cy, w, h) and shape
+ [nb_dec, bs, num_query, 4].
+ img_metas (list[dict]): Meta information of each image.
+ rescale (bool, optional): If True, return boxes in original
+ image space. Default False.
+
+ Returns:
+ list[list[Tensor, Tensor]]: Each item in result_list is 2-tuple. \
+ The first item is an (n, 5) tensor, where the first 4 columns \
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the \
+ 5-th column is a score between 0 and 1. The second item is a \
+ (n,) tensor where each item is the predicted class label of \
+ the corresponding box.
+ """
+        # NOTE by default only outputs from the last feature level are used,
+        # and only the outputs from the last decoder layer are used.
+ cls_scores = all_cls_scores_list[-1][-1]
+ bbox_preds = all_bbox_preds_list[-1][-1]
+
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score = cls_scores[img_id]
+ bbox_pred = bbox_preds[img_id]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ proposals = self._get_bboxes_single(cls_score, bbox_pred,
+ img_shape, scale_factor,
+ rescale)
+ result_list.append(proposals)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False):
+ """Transform outputs from the last decoder layer into bbox predictions
+ for each image.
+
+ Args:
+ cls_score (Tensor): Box score logits from the last decoder layer
+ for each image. Shape [num_query, cls_out_channels].
+ bbox_pred (Tensor): Sigmoid outputs from the last decoder layer
+ for each image, with coordinate format (cx, cy, w, h) and
+ shape [num_query, 4].
+ img_shape (tuple[int]): Shape of input image, (height, width, 3).
+            scale_factor (ndarray, optional): Scale factor of the image arranged
+ as (w_scale, h_scale, w_scale, h_scale).
+ rescale (bool, optional): If True, return boxes in original image
+ space. Default False.
+
+ Returns:
+ tuple[Tensor]: Results of detected bboxes and labels.
+
+ - det_bboxes: Predicted bboxes with shape [num_query, 5], \
+ where the first 4 columns are bounding box positions \
+ (tl_x, tl_y, br_x, br_y) and the 5-th column are scores \
+ between 0 and 1.
+ - det_labels: Predicted labels of the corresponding box with \
+ shape [num_query].
+ """
+ assert len(cls_score) == len(bbox_pred)
+ # exclude background
+ scores, det_labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
+ det_bboxes = bbox_cxcywh_to_xyxy(bbox_pred)
+ det_bboxes[:, 0::2] = det_bboxes[:, 0::2] * img_shape[1]
+ det_bboxes[:, 1::2] = det_bboxes[:, 1::2] * img_shape[0]
+ det_bboxes[:, 0::2].clamp_(min=0, max=img_shape[1])
+ det_bboxes[:, 1::2].clamp_(min=0, max=img_shape[0])
+ if rescale:
+ det_bboxes /= det_bboxes.new_tensor(scale_factor)
+ det_bboxes = torch.cat((det_bboxes, scores.unsqueeze(1)), -1)
+ return det_bboxes, det_labels
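+
+# Post-processing sketch (illustrative, mirroring _get_bboxes_single above):
+# per image, scores and labels come from a softmax over classes with the
+# trailing background column dropped, and boxes are rescaled from normalized
+# cxcywh to absolute xyxy coordinates:
+# >>> scores, labels = F.softmax(cls_score, dim=-1)[..., :-1].max(-1)
+# >>> boxes = bbox_cxcywh_to_xyxy(bbox_pred) * bbox_pred.new_tensor(
+# ...     [img_w, img_h, img_w, img_h])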
diff --git a/mmdet/models/dense_heads/vfnet_head.py b/mmdet/models/dense_heads/vfnet_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..7243bb62893839568ec51928d88a5ad40b02a66c
--- /dev/null
+++ b/mmdet/models/dense_heads/vfnet_head.py
@@ -0,0 +1,794 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Scale, bias_init_with_prob, normal_init
+from mmcv.ops import DeformConv2d
+from mmcv.runner import force_fp32
+
+from mmdet.core import (bbox2distance, bbox_overlaps, build_anchor_generator,
+ build_assigner, build_sampler, distance2bbox,
+ multi_apply, multiclass_nms, reduce_mean)
+from ..builder import HEADS, build_loss
+from .atss_head import ATSSHead
+from .fcos_head import FCOSHead
+
+INF = 1e8
+
+
+@HEADS.register_module()
+class VFNetHead(ATSSHead, FCOSHead):
+ """Head of `VarifocalNet (VFNet): An IoU-aware Dense Object
+ Detector.`_.
+
+ The VFNet predicts IoU-aware classification scores which mix the
+ object presence confidence and object localization accuracy as the
+ detection score. It is built on the FCOS architecture and uses ATSS
+ for defining positive/negative training examples. The VFNet is trained
+    with Varifocal Loss and employs star-shaped deformable convolution to
+ extract features for a bbox.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ regress_ranges (tuple[tuple[int, int]]): Regress range of multiple
+ level points.
+ center_sampling (bool): If true, use center sampling. Default: False.
+ center_sample_radius (float): Radius of center sampling. Default: 1.5.
+ sync_num_pos (bool): If true, synchronize the number of positive
+ examples across GPUs. Default: True
+ gradient_mul (float): The multiplier to gradients from bbox refinement
+ and recognition. Default: 0.1.
+ bbox_norm_type (str): The bbox normalization type, 'reg_denom' or
+ 'stride'. Default: reg_denom
+ loss_cls_fl (dict): Config of focal loss.
+ use_vfl (bool): If true, use varifocal loss for training.
+ Default: True.
+ loss_cls (dict): Config of varifocal loss.
+ loss_bbox (dict): Config of localization loss, GIoU Loss.
+        loss_bbox_refine (dict): Config of localization refinement loss,
+            GIoU Loss.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: norm_cfg=dict(type='GN', num_groups=32,
+ requires_grad=True).
+ use_atss (bool): If true, use ATSS to define positive/negative
+ examples. Default: True.
+ anchor_generator (dict): Config of anchor generator for ATSS.
+
+ Example:
+ >>> self = VFNetHead(11, 7)
+ >>> feats = [torch.rand(1, 7, s, s) for s in [4, 8, 16, 32, 64]]
+        >>> cls_score, bbox_pred, bbox_pred_refine = self.forward(feats)
+ >>> assert len(cls_score) == len(self.scales)
+ """ # noqa: E501
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ regress_ranges=((-1, 64), (64, 128), (128, 256), (256, 512),
+ (512, INF)),
+ center_sampling=False,
+ center_sample_radius=1.5,
+ sync_num_pos=True,
+ gradient_mul=0.1,
+ bbox_norm_type='reg_denom',
+ loss_cls_fl=dict(
+ type='FocalLoss',
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ loss_weight=1.0),
+ use_vfl=True,
+ loss_cls=dict(
+ type='VarifocalLoss',
+ use_sigmoid=True,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ loss_weight=1.0),
+ loss_bbox=dict(type='GIoULoss', loss_weight=1.5),
+ loss_bbox_refine=dict(type='GIoULoss', loss_weight=2.0),
+ norm_cfg=dict(type='GN', num_groups=32, requires_grad=True),
+ use_atss=True,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ ratios=[1.0],
+ octave_base_scale=8,
+ scales_per_octave=1,
+ center_offset=0.0,
+ strides=[8, 16, 32, 64, 128]),
+ **kwargs):
+ # dcn base offsets, adapted from reppoints_head.py
+ self.num_dconv_points = 9
+ self.dcn_kernel = int(np.sqrt(self.num_dconv_points))
+ self.dcn_pad = int((self.dcn_kernel - 1) / 2)
+ dcn_base = np.arange(-self.dcn_pad,
+ self.dcn_pad + 1).astype(np.float64)
+ dcn_base_y = np.repeat(dcn_base, self.dcn_kernel)
+ dcn_base_x = np.tile(dcn_base, self.dcn_kernel)
+ dcn_base_offset = np.stack([dcn_base_y, dcn_base_x], axis=1).reshape(
+ (-1))
+ self.dcn_base_offset = torch.tensor(dcn_base_offset).view(1, -1, 1, 1)
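+        # The resulting tensor has shape (1, 2 * num_dconv_points, 1, 1), i.e.
+        # (1, 18, 1, 1): nine (y, x) offset pairs covering a regular 3x3 grid
+        # centred on each location; the star-shaped offsets computed in
+        # star_dcn_offset() are expressed relative to this base grid.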
+
+ super(FCOSHead, self).__init__(
+ num_classes, in_channels, norm_cfg=norm_cfg, **kwargs)
+ self.regress_ranges = regress_ranges
+ self.reg_denoms = [
+ regress_range[-1] for regress_range in regress_ranges
+ ]
+ self.reg_denoms[-1] = self.reg_denoms[-2] * 2
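+        # reg_denoms are the upper bounds of the regress ranges, used to
+        # normalize bbox predictions when bbox_norm_type == 'reg_denom'; the
+        # last (infinite) bound is replaced by twice the previous one, e.g.
+        # the default ranges give [64, 128, 256, 512, 1024].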
+ self.center_sampling = center_sampling
+ self.center_sample_radius = center_sample_radius
+ self.sync_num_pos = sync_num_pos
+ self.bbox_norm_type = bbox_norm_type
+ self.gradient_mul = gradient_mul
+ self.use_vfl = use_vfl
+ if self.use_vfl:
+ self.loss_cls = build_loss(loss_cls)
+ else:
+ self.loss_cls = build_loss(loss_cls_fl)
+ self.loss_bbox = build_loss(loss_bbox)
+ self.loss_bbox_refine = build_loss(loss_bbox_refine)
+
+ # for getting ATSS targets
+ self.use_atss = use_atss
+ self.use_sigmoid_cls = loss_cls.get('use_sigmoid', False)
+ self.anchor_generator = build_anchor_generator(anchor_generator)
+ self.anchor_center_offset = anchor_generator['center_offset']
+ self.num_anchors = self.anchor_generator.num_base_anchors[0]
+ self.sampling = False
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ super(FCOSHead, self)._init_cls_convs()
+ super(FCOSHead, self)._init_reg_convs()
+ self.relu = nn.ReLU(inplace=True)
+ self.vfnet_reg_conv = ConvModule(
+ self.feat_channels,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=self.conv_bias)
+ self.vfnet_reg = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+ self.scales = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+ self.vfnet_reg_refine_dconv = DeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ self.dcn_kernel,
+ 1,
+ padding=self.dcn_pad)
+ self.vfnet_reg_refine = nn.Conv2d(self.feat_channels, 4, 3, padding=1)
+ self.scales_refine = nn.ModuleList([Scale(1.0) for _ in self.strides])
+
+ self.vfnet_cls_dconv = DeformConv2d(
+ self.feat_channels,
+ self.feat_channels,
+ self.dcn_kernel,
+ 1,
+ padding=self.dcn_pad)
+ self.vfnet_cls = nn.Conv2d(
+ self.feat_channels, self.cls_out_channels, 3, padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.cls_convs:
+ if isinstance(m.conv, nn.Conv2d):
+ normal_init(m.conv, std=0.01)
+ for m in self.reg_convs:
+ if isinstance(m.conv, nn.Conv2d):
+ normal_init(m.conv, std=0.01)
+ normal_init(self.vfnet_reg_conv.conv, std=0.01)
+ normal_init(self.vfnet_reg, std=0.01)
+ normal_init(self.vfnet_reg_refine_dconv, std=0.01)
+ normal_init(self.vfnet_reg_refine, std=0.01)
+ normal_init(self.vfnet_cls_dconv, std=0.01)
+ bias_cls = bias_init_with_prob(0.01)
+ normal_init(self.vfnet_cls, std=0.01, bias=bias_cls)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+ tuple:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box offsets for each
+ scale level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ bbox_preds_refine (list[Tensor]): Refined Box offsets for
+ each scale level, each is a 4D-tensor, the channel
+ number is num_points * 4.
+ """
+ return multi_apply(self.forward_single, feats, self.scales,
+ self.scales_refine, self.strides, self.reg_denoms)
+
+ def forward_single(self, x, scale, scale_refine, stride, reg_denom):
+ """Forward features of a single scale level.
+
+ Args:
+ x (Tensor): FPN feature maps of the specified stride.
+            scale (:obj:`mmcv.cnn.Scale`): Learnable scale module to resize
+                the bbox prediction.
+            scale_refine (:obj:`mmcv.cnn.Scale`): Learnable scale module to
+                resize the refined bbox prediction.
+ stride (int): The corresponding stride for feature maps,
+ used to normalize the bbox prediction when
+ bbox_norm_type = 'stride'.
+ reg_denom (int): The corresponding regression range for feature
+ maps, only used to normalize the bbox prediction when
+ bbox_norm_type = 'reg_denom'.
+
+ Returns:
+ tuple: iou-aware cls scores for each box, bbox predictions and
+ refined bbox predictions of input feature maps.
+ """
+ cls_feat = x
+ reg_feat = x
+
+ for cls_layer in self.cls_convs:
+ cls_feat = cls_layer(cls_feat)
+
+ for reg_layer in self.reg_convs:
+ reg_feat = reg_layer(reg_feat)
+
+ # predict the bbox_pred of different level
+ reg_feat_init = self.vfnet_reg_conv(reg_feat)
+ if self.bbox_norm_type == 'reg_denom':
+ bbox_pred = scale(
+ self.vfnet_reg(reg_feat_init)).float().exp() * reg_denom
+ elif self.bbox_norm_type == 'stride':
+ bbox_pred = scale(
+ self.vfnet_reg(reg_feat_init)).float().exp() * stride
+ else:
+ raise NotImplementedError
+
+ # compute star deformable convolution offsets
+        # converting dcn_offset to reg_feat.dtype so that VFNet can be
+ # trained with FP16
+ dcn_offset = self.star_dcn_offset(bbox_pred, self.gradient_mul,
+ stride).to(reg_feat.dtype)
+
+ # refine the bbox_pred
+ reg_feat = self.relu(self.vfnet_reg_refine_dconv(reg_feat, dcn_offset))
+ bbox_pred_refine = scale_refine(
+ self.vfnet_reg_refine(reg_feat)).float().exp()
+ bbox_pred_refine = bbox_pred_refine * bbox_pred.detach()
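+        # The refinement branch predicts a multiplicative correction on top of
+        # the detached initial prediction, so the refinement loss does not
+        # backpropagate into bbox_pred through this product; only a
+        # gradient_mul-scaled fraction flows back via the dcn offsets.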
+
+ # predict the iou-aware cls score
+ cls_feat = self.relu(self.vfnet_cls_dconv(cls_feat, dcn_offset))
+ cls_score = self.vfnet_cls(cls_feat)
+
+ return cls_score, bbox_pred, bbox_pred_refine
+
+ def star_dcn_offset(self, bbox_pred, gradient_mul, stride):
+ """Compute the star deformable conv offsets.
+
+ Args:
+ bbox_pred (Tensor): Predicted bbox distance offsets (l, r, t, b).
+ gradient_mul (float): Gradient multiplier.
+ stride (int): The corresponding stride for feature maps,
+ used to project the bbox onto the feature map.
+
+ Returns:
+ dcn_offsets (Tensor): The offsets for deformable convolution.
+ """
+ dcn_base_offset = self.dcn_base_offset.type_as(bbox_pred)
+ bbox_pred_grad_mul = (1 - gradient_mul) * bbox_pred.detach() + \
+ gradient_mul * bbox_pred
+ # map to the feature map scale
+ bbox_pred_grad_mul = bbox_pred_grad_mul / stride
+ N, C, H, W = bbox_pred.size()
+
+ x1 = bbox_pred_grad_mul[:, 0, :, :]
+ y1 = bbox_pred_grad_mul[:, 1, :, :]
+ x2 = bbox_pred_grad_mul[:, 2, :, :]
+ y2 = bbox_pred_grad_mul[:, 3, :, :]
+ bbox_pred_grad_mul_offset = bbox_pred.new_zeros(
+ N, 2 * self.num_dconv_points, H, W)
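+        # The 18 offset channels are (y, x) pairs for the 9 sampling points of
+        # a 3x3 deformable kernel, laid out row by row: the top row uses -y1,
+        # the bottom row uses y2, the left column uses -x1, the right column
+        # uses x2, and the centre point (channels 8, 9) stays at zero. This
+        # places the sampling points in a "star" over the predicted box
+        # (corners, edge midpoints and centre).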
+ bbox_pred_grad_mul_offset[:, 0, :, :] = -1.0 * y1 # -y1
+ bbox_pred_grad_mul_offset[:, 1, :, :] = -1.0 * x1 # -x1
+ bbox_pred_grad_mul_offset[:, 2, :, :] = -1.0 * y1 # -y1
+ bbox_pred_grad_mul_offset[:, 4, :, :] = -1.0 * y1 # -y1
+ bbox_pred_grad_mul_offset[:, 5, :, :] = x2 # x2
+ bbox_pred_grad_mul_offset[:, 7, :, :] = -1.0 * x1 # -x1
+ bbox_pred_grad_mul_offset[:, 11, :, :] = x2 # x2
+ bbox_pred_grad_mul_offset[:, 12, :, :] = y2 # y2
+ bbox_pred_grad_mul_offset[:, 13, :, :] = -1.0 * x1 # -x1
+ bbox_pred_grad_mul_offset[:, 14, :, :] = y2 # y2
+ bbox_pred_grad_mul_offset[:, 16, :, :] = y2 # y2
+ bbox_pred_grad_mul_offset[:, 17, :, :] = x2 # x2
+ dcn_offset = bbox_pred_grad_mul_offset - dcn_base_offset
+
+ return dcn_offset
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'bbox_preds_refine'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ bbox_preds_refine,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_preds (list[Tensor]): Box offsets for each
+ scale level, each is a 4D-tensor, the channel number is
+ num_points * 4.
+ bbox_preds_refine (list[Tensor]): Refined Box offsets for
+ each scale level, each is a 4D-tensor, the channel
+ number is num_points * 4.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ Default: None.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine)
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ all_level_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
+ bbox_preds[0].device)
+ labels, label_weights, bbox_targets, bbox_weights = self.get_targets(
+ cls_scores, all_level_points, gt_bboxes, gt_labels, img_metas,
+ gt_bboxes_ignore)
+
+ num_imgs = cls_scores[0].size(0)
+ # flatten cls_scores, bbox_preds and bbox_preds_refine
+ flatten_cls_scores = [
+ cls_score.permute(0, 2, 3,
+ 1).reshape(-1,
+ self.cls_out_channels).contiguous()
+ for cls_score in cls_scores
+ ]
+ flatten_bbox_preds = [
+ bbox_pred.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
+ for bbox_pred in bbox_preds
+ ]
+ flatten_bbox_preds_refine = [
+ bbox_pred_refine.permute(0, 2, 3, 1).reshape(-1, 4).contiguous()
+ for bbox_pred_refine in bbox_preds_refine
+ ]
+ flatten_cls_scores = torch.cat(flatten_cls_scores)
+ flatten_bbox_preds = torch.cat(flatten_bbox_preds)
+ flatten_bbox_preds_refine = torch.cat(flatten_bbox_preds_refine)
+ flatten_labels = torch.cat(labels)
+ flatten_bbox_targets = torch.cat(bbox_targets)
+ # repeat points to align with bbox_preds
+ flatten_points = torch.cat(
+ [points.repeat(num_imgs, 1) for points in all_level_points])
+
+ # FG cat_id: [0, num_classes - 1], BG cat_id: num_classes
+ bg_class_ind = self.num_classes
+ pos_inds = torch.where(
+ ((flatten_labels >= 0) & (flatten_labels < bg_class_ind)) > 0)[0]
+ num_pos = len(pos_inds)
+
+ pos_bbox_preds = flatten_bbox_preds[pos_inds]
+ pos_bbox_preds_refine = flatten_bbox_preds_refine[pos_inds]
+ pos_labels = flatten_labels[pos_inds]
+
+ # sync num_pos across all gpus
+ if self.sync_num_pos:
+ num_pos_avg_per_gpu = reduce_mean(
+ pos_inds.new_tensor(num_pos).float()).item()
+ num_pos_avg_per_gpu = max(num_pos_avg_per_gpu, 1.0)
+ else:
+ num_pos_avg_per_gpu = num_pos
+
+ if num_pos > 0:
+ pos_bbox_targets = flatten_bbox_targets[pos_inds]
+ pos_points = flatten_points[pos_inds]
+
+ pos_decoded_bbox_preds = distance2bbox(pos_points, pos_bbox_preds)
+ pos_decoded_target_preds = distance2bbox(pos_points,
+ pos_bbox_targets)
+ iou_targets_ini = bbox_overlaps(
+ pos_decoded_bbox_preds,
+ pos_decoded_target_preds.detach(),
+ is_aligned=True).clamp(min=1e-6)
+ bbox_weights_ini = iou_targets_ini.clone().detach()
+ iou_targets_ini_avg_per_gpu = reduce_mean(
+ bbox_weights_ini.sum()).item()
+ bbox_avg_factor_ini = max(iou_targets_ini_avg_per_gpu, 1.0)
+ loss_bbox = self.loss_bbox(
+ pos_decoded_bbox_preds,
+ pos_decoded_target_preds.detach(),
+ weight=bbox_weights_ini,
+ avg_factor=bbox_avg_factor_ini)
+
+ pos_decoded_bbox_preds_refine = \
+ distance2bbox(pos_points, pos_bbox_preds_refine)
+ iou_targets_rf = bbox_overlaps(
+ pos_decoded_bbox_preds_refine,
+ pos_decoded_target_preds.detach(),
+ is_aligned=True).clamp(min=1e-6)
+ bbox_weights_rf = iou_targets_rf.clone().detach()
+ iou_targets_rf_avg_per_gpu = reduce_mean(
+ bbox_weights_rf.sum()).item()
+ bbox_avg_factor_rf = max(iou_targets_rf_avg_per_gpu, 1.0)
+ loss_bbox_refine = self.loss_bbox_refine(
+ pos_decoded_bbox_preds_refine,
+ pos_decoded_target_preds.detach(),
+ weight=bbox_weights_rf,
+ avg_factor=bbox_avg_factor_rf)
+
+ # build IoU-aware cls_score targets
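+            # For Varifocal Loss the target at each positive location is the
+            # IoU of its refined prediction with the matched ground-truth box;
+            # all other entries stay 0.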
+ if self.use_vfl:
+ pos_ious = iou_targets_rf.clone().detach()
+ cls_iou_targets = torch.zeros_like(flatten_cls_scores)
+ cls_iou_targets[pos_inds, pos_labels] = pos_ious
+ else:
+ loss_bbox = pos_bbox_preds.sum() * 0
+ loss_bbox_refine = pos_bbox_preds_refine.sum() * 0
+ if self.use_vfl:
+ cls_iou_targets = torch.zeros_like(flatten_cls_scores)
+
+ if self.use_vfl:
+ loss_cls = self.loss_cls(
+ flatten_cls_scores,
+ cls_iou_targets,
+ avg_factor=num_pos_avg_per_gpu)
+ else:
+ loss_cls = self.loss_cls(
+ flatten_cls_scores,
+ flatten_labels,
+ weight=label_weights,
+ avg_factor=num_pos_avg_per_gpu)
+
+ return dict(
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ loss_bbox_rf=loss_bbox_refine)
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'bbox_preds_refine'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ bbox_preds_refine,
+ img_metas,
+ cfg=None,
+ rescale=None,
+ with_nms=True):
+ """Transform network outputs for a batch into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level with shape (N, num_points * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box offsets for each scale
+ level with shape (N, num_points * 4, H, W).
+ bbox_preds_refine (list[Tensor]): Refined Box offsets for
+ each scale level with shape (N, num_points * 4, H, W).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default: None.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before returning boxes.
+ Default: True.
+
+ Returns:
+ list[tuple[Tensor, Tensor]]: Each item in result_list is 2-tuple.
+ The first item is an (n, 5) tensor, where the first 4 columns
+ are bounding box positions (tl_x, tl_y, br_x, br_y) and the
+ 5-th column is a score between 0 and 1. The second item is a
+ (n,) tensor where each item is the predicted class label of
+ the corresponding box.
+ """
+ assert len(cls_scores) == len(bbox_preds) == len(bbox_preds_refine)
+ num_levels = len(cls_scores)
+
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ mlvl_points = self.get_points(featmap_sizes, bbox_preds[0].dtype,
+ bbox_preds[0].device)
+ result_list = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds_refine[i][img_id].detach()
+ for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ det_bboxes = self._get_bboxes_single(cls_score_list,
+ bbox_pred_list, mlvl_points,
+ img_shape, scale_factor, cfg,
+ rescale, with_nms)
+ result_list.append(det_bboxes)
+ return result_list
+
+ def _get_bboxes_single(self,
+ cls_scores,
+ bbox_preds,
+ mlvl_points,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for a single scale
+ level with shape (num_points * num_classes, H, W).
+ bbox_preds (list[Tensor]): Box offsets for a single scale
+ level with shape (num_points * 4, H, W).
+ mlvl_points (list[Tensor]): Box reference for a single scale level
+ with shape (num_total_points, 4).
+ img_shape (tuple[int]): Shape of the input image,
+ (height, width, 3).
+            scale_factor (ndarray): Scale factor of the image, arranged as
+ (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before returning boxes.
+ Default: True.
+
+ Returns:
+ tuple(Tensor):
+ det_bboxes (Tensor): BBox predictions in shape (n, 5), where
+ the first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+ between 0 and 1.
+ det_labels (Tensor): A (n,) tensor where each item is the
+ predicted class label of the corresponding box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_scores) == len(bbox_preds) == len(mlvl_points)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ for cls_score, bbox_pred, points in zip(cls_scores, bbox_preds,
+ mlvl_points):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ scores = cls_score.permute(1, 2, 0).reshape(
+ -1, self.cls_out_channels).contiguous().sigmoid()
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4).contiguous()
+
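+            # Keep only the nms_pre top-scoring candidates per level before
+            # decoding and NMS; nms_pre <= 0 disables this pre-filtering.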
+ nms_pre = cfg.get('nms_pre', -1)
+ if 0 < nms_pre < scores.shape[0]:
+ max_scores, _ = scores.max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ points = points[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ bboxes = distance2bbox(points, bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ if with_nms:
+ det_bboxes, det_labels = multiclass_nms(mlvl_bboxes, mlvl_scores,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ return det_bboxes, det_labels
+ else:
+ return mlvl_bboxes, mlvl_scores
+
+ def _get_points_single(self,
+ featmap_size,
+ stride,
+ dtype,
+ device,
+ flatten=False):
+ """Get points according to feature map sizes."""
+ h, w = featmap_size
+ x_range = torch.arange(
+ 0, w * stride, stride, dtype=dtype, device=device)
+ y_range = torch.arange(
+ 0, h * stride, stride, dtype=dtype, device=device)
+ y, x = torch.meshgrid(y_range, x_range)
+ # to be compatible with anchor points in ATSS
+ if self.use_atss:
+ points = torch.stack(
+ (x.reshape(-1), y.reshape(-1)), dim=-1) + \
+ stride * self.anchor_center_offset
+ else:
+ points = torch.stack(
+ (x.reshape(-1), y.reshape(-1)), dim=-1) + stride // 2
+ return points
+
+ def get_targets(self, cls_scores, mlvl_points, gt_bboxes, gt_labels,
+ img_metas, gt_bboxes_ignore):
+ """A wrapper for computing ATSS and FCOS targets for points in multiple
+ images.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level with shape (N, num_points * num_classes, H, W).
+ mlvl_points (list[Tensor]): Points of each fpn level, each has
+ shape (num_points, 2).
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4).
+
+ Returns:
+ tuple:
+ labels_list (list[Tensor]): Labels of each level.
+ label_weights (Tensor/None): Label weights of all levels.
+ bbox_targets_list (list[Tensor]): Regression targets of each
+ level, (l, t, r, b).
+ bbox_weights (Tensor/None): Bbox weights of all levels.
+ """
+ if self.use_atss:
+ return self.get_atss_targets(cls_scores, mlvl_points, gt_bboxes,
+ gt_labels, img_metas,
+ gt_bboxes_ignore)
+ else:
+ self.norm_on_bbox = False
+ return self.get_fcos_targets(mlvl_points, gt_bboxes, gt_labels)
+
+ def _get_target_single(self, *args, **kwargs):
+ """Avoid ambiguity in multiple inheritance."""
+ if self.use_atss:
+ return ATSSHead._get_target_single(self, *args, **kwargs)
+ else:
+ return FCOSHead._get_target_single(self, *args, **kwargs)
+
+ def get_fcos_targets(self, points, gt_bboxes_list, gt_labels_list):
+ """Compute FCOS regression and classification targets for points in
+ multiple images.
+
+ Args:
+ points (list[Tensor]): Points of each fpn level, each has shape
+ (num_points, 2).
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels_list (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+
+ Returns:
+ tuple:
+ labels (list[Tensor]): Labels of each level.
+ label_weights: None, to be compatible with ATSS targets.
+ bbox_targets (list[Tensor]): BBox targets of each level.
+ bbox_weights: None, to be compatible with ATSS targets.
+ """
+ labels, bbox_targets = FCOSHead.get_targets(self, points,
+ gt_bboxes_list,
+ gt_labels_list)
+ label_weights = None
+ bbox_weights = None
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ def get_atss_targets(self,
+ cls_scores,
+ mlvl_points,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """A wrapper for computing ATSS targets for points in multiple images.
+
+ Args:
+ cls_scores (list[Tensor]): Box iou-aware scores for each scale
+ level with shape (N, num_points * num_classes, H, W).
+ mlvl_points (list[Tensor]): Points of each fpn level, each has
+ shape (num_points, 2).
+ gt_bboxes (list[Tensor]): Ground truth bboxes of each image,
+ each has shape (num_gt, 4).
+ gt_labels (list[Tensor]): Ground truth labels of each box,
+ each has shape (num_gt,).
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | Tensor): Ground truth bboxes to be
+ ignored, shape (num_ignored_gts, 4). Default: None.
+
+ Returns:
+ tuple:
+ labels_list (list[Tensor]): Labels of each level.
+ label_weights (Tensor): Label weights of all levels.
+ bbox_targets_list (list[Tensor]): Regression targets of each
+ level, (l, t, r, b).
+ bbox_weights (Tensor): Bbox weights of all levels.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+
+ cls_reg_targets = ATSSHead.get_targets(
+ self,
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ unmap_outputs=True)
+ if cls_reg_targets is None:
+ return None
+
+ (anchor_list, labels_list, label_weights_list, bbox_targets_list,
+ bbox_weights_list, num_total_pos, num_total_neg) = cls_reg_targets
+
+ bbox_targets_list = [
+ bbox_targets.reshape(-1, 4) for bbox_targets in bbox_targets_list
+ ]
+
+ num_imgs = len(img_metas)
+ # transform bbox_targets (x1, y1, x2, y2) into (l, t, r, b) format
+ bbox_targets_list = self.transform_bbox_targets(
+ bbox_targets_list, mlvl_points, num_imgs)
+
+ labels_list = [labels.reshape(-1) for labels in labels_list]
+ label_weights_list = [
+ label_weights.reshape(-1) for label_weights in label_weights_list
+ ]
+ bbox_weights_list = [
+ bbox_weights.reshape(-1) for bbox_weights in bbox_weights_list
+ ]
+ label_weights = torch.cat(label_weights_list)
+ bbox_weights = torch.cat(bbox_weights_list)
+ return labels_list, label_weights, bbox_targets_list, bbox_weights
+
+ def transform_bbox_targets(self, decoded_bboxes, mlvl_points, num_imgs):
+ """Transform bbox_targets (x1, y1, x2, y2) into (l, t, r, b) format.
+
+ Args:
+ decoded_bboxes (list[Tensor]): Regression targets of each level,
+ in the form of (x1, y1, x2, y2).
+ mlvl_points (list[Tensor]): Points of each fpn level, each has
+ shape (num_points, 2).
+ num_imgs (int): the number of images in a batch.
+
+ Returns:
+ bbox_targets (list[Tensor]): Regression targets of each level in
+ the form of (l, t, r, b).
+ """
+ # TODO: Re-implemented in Class PointCoder
+ assert len(decoded_bboxes) == len(mlvl_points)
+ num_levels = len(decoded_bboxes)
+ mlvl_points = [points.repeat(num_imgs, 1) for points in mlvl_points]
+ bbox_targets = []
+ for i in range(num_levels):
+ bbox_target = bbox2distance(mlvl_points[i], decoded_bboxes[i])
+ bbox_targets.append(bbox_target)
+
+ return bbox_targets
+
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
+ missing_keys, unexpected_keys, error_msgs):
+        """Override the method in the parent class to avoid changing the
+        parameters' names."""
+ pass
diff --git a/mmdet/models/dense_heads/yolact_head.py b/mmdet/models/dense_heads/yolact_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d311f94ee99e1bf65ee3e5827f1699c28a23e3
--- /dev/null
+++ b/mmdet/models/dense_heads/yolact_head.py
@@ -0,0 +1,943 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import build_sampler, fast_nms, images_to_levels, multi_apply
+from ..builder import HEADS, build_loss
+from .anchor_head import AnchorHead
+
+
+@HEADS.register_module()
+class YOLACTHead(AnchorHead):
+ """YOLACT box head used in https://arxiv.org/abs/1904.02689.
+
+ Note that YOLACT head is a light version of RetinaNet head.
+ Four differences are described as follows:
+
+ 1. YOLACT box head has three-times fewer anchors.
+ 2. YOLACT box head shares the convs for box and cls branches.
+ 3. YOLACT box head uses OHEM instead of Focal loss.
+ 4. YOLACT box head predicts a set of mask coefficients for each box.
+
+ Args:
+ num_classes (int): Number of categories excluding the background
+ category.
+ in_channels (int): Number of channels in the input feature map.
+ anchor_generator (dict): Config dict for anchor generator
+ loss_cls (dict): Config of classification loss.
+ loss_bbox (dict): Config of localization loss.
+ num_head_convs (int): Number of the conv layers shared by
+ box and cls branches.
+ num_protos (int): Number of the mask coefficients.
+ use_ohem (bool): If true, ``loss_single_OHEM`` will be used for
+ cls loss calculation. If false, ``loss_single`` will be used.
+ conv_cfg (dict): Dictionary to construct and config conv layer.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ anchor_generator=dict(
+ type='AnchorGenerator',
+ octave_base_scale=3,
+ scales_per_octave=1,
+ ratios=[0.5, 1.0, 2.0],
+ strides=[8, 16, 32, 64, 128]),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ reduction='none',
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1.5),
+ num_head_convs=1,
+ num_protos=32,
+ use_ohem=True,
+ conv_cfg=None,
+ norm_cfg=None,
+ **kwargs):
+ self.num_head_convs = num_head_convs
+ self.num_protos = num_protos
+ self.use_ohem = use_ohem
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ super(YOLACTHead, self).__init__(
+ num_classes,
+ in_channels,
+ loss_cls=loss_cls,
+ loss_bbox=loss_bbox,
+ anchor_generator=anchor_generator,
+ **kwargs)
+ if self.use_ohem:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+ self.sampling = False
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.relu = nn.ReLU(inplace=True)
+ self.head_convs = nn.ModuleList()
+ for i in range(self.num_head_convs):
+ chn = self.in_channels if i == 0 else self.feat_channels
+ self.head_convs.append(
+ ConvModule(
+ chn,
+ self.feat_channels,
+ 3,
+ stride=1,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.conv_cls = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.cls_out_channels,
+ 3,
+ padding=1)
+ self.conv_reg = nn.Conv2d(
+ self.feat_channels, self.num_anchors * 4, 3, padding=1)
+ self.conv_coeff = nn.Conv2d(
+ self.feat_channels,
+ self.num_anchors * self.num_protos,
+ 3,
+ padding=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.head_convs:
+ xavier_init(m.conv, distribution='uniform', bias=0)
+ xavier_init(self.conv_cls, distribution='uniform', bias=0)
+ xavier_init(self.conv_reg, distribution='uniform', bias=0)
+ xavier_init(self.conv_coeff, distribution='uniform', bias=0)
+
+ def forward_single(self, x):
+ """Forward feature of a single scale level.
+
+ Args:
+ x (Tensor): Features of a single scale level.
+
+ Returns:
+ tuple:
+                cls_score (Tensor): Cls scores for a single scale level, \
+                    the channel number is num_anchors * num_classes.
+                bbox_pred (Tensor): Box energies / deltas for a single scale \
+                    level, the channel number is num_anchors * 4.
+                coeff_pred (Tensor): Mask coefficients for a single scale \
+                    level, the channel number is num_anchors * num_protos.
+ """
+ for head_conv in self.head_convs:
+ x = head_conv(x)
+ cls_score = self.conv_cls(x)
+ bbox_pred = self.conv_reg(x)
+ coeff_pred = self.conv_coeff(x).tanh()
+ return cls_score, bbox_pred, coeff_pred
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds'))
+ def loss(self,
+ cls_scores,
+ bbox_preds,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """A combination of the func:``AnchorHead.loss`` and
+ func:``SSDHead.loss``.
+
+ When ``self.use_ohem == True``, it functions like ``SSDHead.loss``,
+ otherwise, it follows ``AnchorHead.loss``. Besides, it additionally
+ returns ``sampling_results``.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ Has shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss. Default: None
+
+ Returns:
+ tuple:
+ dict[str, Tensor]: A dictionary of loss components.
+ List[:obj:``SamplingResult``]: Sampler results for each image.
+ """
+ featmap_sizes = [featmap.size()[-2:] for featmap in cls_scores]
+ assert len(featmap_sizes) == self.anchor_generator.num_levels
+
+ device = cls_scores[0].device
+
+ anchor_list, valid_flag_list = self.get_anchors(
+ featmap_sizes, img_metas, device=device)
+ label_channels = self.cls_out_channels if self.use_sigmoid_cls else 1
+ cls_reg_targets = self.get_targets(
+ anchor_list,
+ valid_flag_list,
+ gt_bboxes,
+ img_metas,
+ gt_bboxes_ignore_list=gt_bboxes_ignore,
+ gt_labels_list=gt_labels,
+ label_channels=label_channels,
+ unmap_outputs=not self.use_ohem,
+ return_sampling_results=True)
+ if cls_reg_targets is None:
+ return None
+ (labels_list, label_weights_list, bbox_targets_list, bbox_weights_list,
+ num_total_pos, num_total_neg, sampling_results) = cls_reg_targets
+
+ if self.use_ohem:
+ num_images = len(img_metas)
+ all_cls_scores = torch.cat([
+ s.permute(0, 2, 3, 1).reshape(
+ num_images, -1, self.cls_out_channels) for s in cls_scores
+ ], 1)
+ all_labels = torch.cat(labels_list, -1).view(num_images, -1)
+ all_label_weights = torch.cat(label_weights_list,
+ -1).view(num_images, -1)
+ all_bbox_preds = torch.cat([
+ b.permute(0, 2, 3, 1).reshape(num_images, -1, 4)
+ for b in bbox_preds
+ ], -2)
+ all_bbox_targets = torch.cat(bbox_targets_list,
+ -2).view(num_images, -1, 4)
+ all_bbox_weights = torch.cat(bbox_weights_list,
+ -2).view(num_images, -1, 4)
+
+ # concat all level anchors to a single tensor
+ all_anchors = []
+ for i in range(num_images):
+ all_anchors.append(torch.cat(anchor_list[i]))
+
+ # check NaN and Inf
+ assert torch.isfinite(all_cls_scores).all().item(), \
+ 'classification scores become infinite or NaN!'
+ assert torch.isfinite(all_bbox_preds).all().item(), \
+                'bbox predictions become infinite or NaN!'
+
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single_OHEM,
+ all_cls_scores,
+ all_bbox_preds,
+ all_anchors,
+ all_labels,
+ all_label_weights,
+ all_bbox_targets,
+ all_bbox_weights,
+ num_total_samples=num_total_pos)
+ else:
+ num_total_samples = (
+ num_total_pos +
+ num_total_neg if self.sampling else num_total_pos)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+ # concat all level anchors and flags to a single tensor
+ concat_anchor_list = []
+ for i in range(len(anchor_list)):
+ concat_anchor_list.append(torch.cat(anchor_list[i]))
+ all_anchor_list = images_to_levels(concat_anchor_list,
+ num_level_anchors)
+ losses_cls, losses_bbox = multi_apply(
+ self.loss_single,
+ cls_scores,
+ bbox_preds,
+ all_anchor_list,
+ labels_list,
+ label_weights_list,
+ bbox_targets_list,
+ bbox_weights_list,
+ num_total_samples=num_total_samples)
+
+ return dict(
+ loss_cls=losses_cls, loss_bbox=losses_bbox), sampling_results
+
+ def loss_single_OHEM(self, cls_score, bbox_pred, anchors, labels,
+ label_weights, bbox_targets, bbox_weights,
+ num_total_samples):
+        """See func:``SSDHead.loss``."""
+ loss_cls_all = self.loss_cls(cls_score, labels, label_weights)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ pos_inds = ((labels >= 0) & (labels < self.num_classes)).nonzero(
+ as_tuple=False).reshape(-1)
+ neg_inds = (labels == self.num_classes).nonzero(
+ as_tuple=False).view(-1)
+
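+        # Online hard example mining: keep the classification loss of all
+        # positives plus the top-k highest-loss negatives, where
+        # k = neg_pos_ratio * num_pos (capped at the number of negatives); if
+        # there are no positives, all negatives are kept.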
+ num_pos_samples = pos_inds.size(0)
+ if num_pos_samples == 0:
+ num_neg_samples = neg_inds.size(0)
+ else:
+ num_neg_samples = self.train_cfg.neg_pos_ratio * num_pos_samples
+ if num_neg_samples > neg_inds.size(0):
+ num_neg_samples = neg_inds.size(0)
+ topk_loss_cls_neg, _ = loss_cls_all[neg_inds].topk(num_neg_samples)
+ loss_cls_pos = loss_cls_all[pos_inds].sum()
+ loss_cls_neg = topk_loss_cls_neg.sum()
+ loss_cls = (loss_cls_pos + loss_cls_neg) / num_total_samples
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, it
+ # decodes the already encoded coordinates to absolute format.
+ bbox_pred = self.bbox_coder.decode(anchors, bbox_pred)
+ loss_bbox = self.loss_bbox(
+ bbox_pred,
+ bbox_targets,
+ bbox_weights,
+ avg_factor=num_total_samples)
+ return loss_cls[None], loss_bbox
+
+ @force_fp32(apply_to=('cls_scores', 'bbox_preds', 'coeff_preds'))
+ def get_bboxes(self,
+ cls_scores,
+ bbox_preds,
+ coeff_preds,
+ img_metas,
+ cfg=None,
+ rescale=False):
+        """Similar to func:``AnchorHead.get_bboxes``, but additionally
+ processes coeff_preds.
+
+ Args:
+ cls_scores (list[Tensor]): Box scores for each scale level
+ with shape (N, num_anchors * num_classes, H, W)
+ bbox_preds (list[Tensor]): Box energies / deltas for each scale
+ level with shape (N, num_anchors * 4, H, W)
+ coeff_preds (list[Tensor]): Mask coefficients for each scale
+ level with shape (N, num_anchors * num_protos, H, W)
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Returns:
+ list[tuple[Tensor, Tensor, Tensor]]: Each item in result_list is
+ a 3-tuple. The first item is an (n, 5) tensor, where the
+ first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score
+ between 0 and 1. The second item is an (n,) tensor where each
+ item is the predicted class label of the corresponding box.
+ The third item is an (n, num_protos) tensor where each item
+ is the predicted mask coefficients of instance inside the
+ corresponding box.
+ """
+ assert len(cls_scores) == len(bbox_preds)
+ num_levels = len(cls_scores)
+
+ device = cls_scores[0].device
+ featmap_sizes = [cls_scores[i].shape[-2:] for i in range(num_levels)]
+ mlvl_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device=device)
+
+ det_bboxes = []
+ det_labels = []
+ det_coeffs = []
+ for img_id in range(len(img_metas)):
+ cls_score_list = [
+ cls_scores[i][img_id].detach() for i in range(num_levels)
+ ]
+ bbox_pred_list = [
+ bbox_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ coeff_pred_list = [
+ coeff_preds[i][img_id].detach() for i in range(num_levels)
+ ]
+ img_shape = img_metas[img_id]['img_shape']
+ scale_factor = img_metas[img_id]['scale_factor']
+ bbox_res = self._get_bboxes_single(cls_score_list, bbox_pred_list,
+ coeff_pred_list, mlvl_anchors,
+ img_shape, scale_factor, cfg,
+ rescale)
+ det_bboxes.append(bbox_res[0])
+ det_labels.append(bbox_res[1])
+ det_coeffs.append(bbox_res[2])
+ return det_bboxes, det_labels, det_coeffs
+
+ def _get_bboxes_single(self,
+ cls_score_list,
+ bbox_pred_list,
+ coeff_preds_list,
+ mlvl_anchors,
+ img_shape,
+ scale_factor,
+ cfg,
+ rescale=False):
+        """Similar to func:``AnchorHead._get_bboxes_single``, but
+ additionally processes coeff_preds_list and uses fast NMS instead of
+ traditional NMS.
+
+ Args:
+ cls_score_list (list[Tensor]): Box scores for a single scale level
+ Has shape (num_anchors * num_classes, H, W).
+ bbox_pred_list (list[Tensor]): Box energies / deltas for a single
+ scale level with shape (num_anchors * 4, H, W).
+ coeff_preds_list (list[Tensor]): Mask coefficients for a single
+ scale level with shape (num_anchors * num_protos, H, W).
+ mlvl_anchors (list[Tensor]): Box reference for a single scale level
+ with shape (num_total_anchors, 4).
+ img_shape (tuple[int]): Shape of the input image,
+ (height, width, 3).
+            scale_factor (ndarray): Scale factor of the image, arranged as
+ (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+
+ Returns:
+ tuple[Tensor, Tensor, Tensor]: The first item is an (n, 5) tensor,
+ where the first 4 columns are bounding box positions
+ (tl_x, tl_y, br_x, br_y) and the 5-th column is a score between
+ 0 and 1. The second item is an (n,) tensor where each item is
+ the predicted class label of the corresponding box. The third
+ item is an (n, num_protos) tensor where each item is the
+ predicted mask coefficients of instance inside the
+ corresponding box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(cls_score_list) == len(bbox_pred_list) == len(mlvl_anchors)
+ mlvl_bboxes = []
+ mlvl_scores = []
+ mlvl_coeffs = []
+ for cls_score, bbox_pred, coeff_pred, anchors in \
+ zip(cls_score_list, bbox_pred_list,
+ coeff_preds_list, mlvl_anchors):
+ assert cls_score.size()[-2:] == bbox_pred.size()[-2:]
+ cls_score = cls_score.permute(1, 2,
+ 0).reshape(-1, self.cls_out_channels)
+ if self.use_sigmoid_cls:
+ scores = cls_score.sigmoid()
+ else:
+ scores = cls_score.softmax(-1)
+ bbox_pred = bbox_pred.permute(1, 2, 0).reshape(-1, 4)
+ coeff_pred = coeff_pred.permute(1, 2,
+ 0).reshape(-1, self.num_protos)
+ nms_pre = cfg.get('nms_pre', -1)
+ if nms_pre > 0 and scores.shape[0] > nms_pre:
+ # Get maximum scores for foreground classes.
+ if self.use_sigmoid_cls:
+ max_scores, _ = scores.max(dim=1)
+ else:
+ # remind that we set FG labels to [0, num_class-1]
+ # since mmdet v2.0
+ # BG cat_id: num_class
+ max_scores, _ = scores[:, :-1].max(dim=1)
+ _, topk_inds = max_scores.topk(nms_pre)
+ anchors = anchors[topk_inds, :]
+ bbox_pred = bbox_pred[topk_inds, :]
+ scores = scores[topk_inds, :]
+ coeff_pred = coeff_pred[topk_inds, :]
+ bboxes = self.bbox_coder.decode(
+ anchors, bbox_pred, max_shape=img_shape)
+ mlvl_bboxes.append(bboxes)
+ mlvl_scores.append(scores)
+ mlvl_coeffs.append(coeff_pred)
+ mlvl_bboxes = torch.cat(mlvl_bboxes)
+ if rescale:
+ mlvl_bboxes /= mlvl_bboxes.new_tensor(scale_factor)
+ mlvl_scores = torch.cat(mlvl_scores)
+ mlvl_coeffs = torch.cat(mlvl_coeffs)
+ if self.use_sigmoid_cls:
+ # Add a dummy background class to the backend when using sigmoid
+ # remind that we set FG labels to [0, num_class-1] since mmdet v2.0
+ # BG cat_id: num_class
+ padding = mlvl_scores.new_zeros(mlvl_scores.shape[0], 1)
+ mlvl_scores = torch.cat([mlvl_scores, padding], dim=1)
+ det_bboxes, det_labels, det_coeffs = fast_nms(mlvl_bboxes, mlvl_scores,
+ mlvl_coeffs,
+ cfg.score_thr,
+ cfg.iou_thr, cfg.top_k,
+ cfg.max_per_img)
+ return det_bboxes, det_labels, det_coeffs
+
+
+@HEADS.register_module()
+class YOLACTSegmHead(nn.Module):
+ """YOLACT segmentation head used in https://arxiv.org/abs/1904.02689.
+
+ Apply a semantic segmentation loss on feature space using layers that are
+ only evaluated during training to increase performance with no speed
+ penalty.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ num_classes (int): Number of categories excluding the background
+ category.
+ loss_segm (dict): Config of semantic segmentation loss.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels=256,
+ loss_segm=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0)):
+ super(YOLACTSegmHead, self).__init__()
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.loss_segm = build_loss(loss_segm)
+ self._init_layers()
+ self.fp16_enabled = False
+
+ def _init_layers(self):
+ """Initialize layers of the head."""
+ self.segm_conv = nn.Conv2d(
+ self.in_channels, self.num_classes, kernel_size=1)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ xavier_init(self.segm_conv, distribution='uniform')
+
+ def forward(self, x):
+ """Forward feature from the upstream network.
+
+ Args:
+ x (Tensor): Feature from the upstream network, which is
+ a 4D-tensor.
+
+ Returns:
+ Tensor: Predicted semantic segmentation map with shape
+ (N, num_classes, H, W).
+ """
+ return self.segm_conv(x)
+
+ @force_fp32(apply_to=('segm_pred', ))
+ def loss(self, segm_pred, gt_masks, gt_labels):
+ """Compute loss of the head.
+
+ Args:
+ segm_pred (list[Tensor]): Predicted semantic segmentation map
+ with shape (N, num_classes, H, W).
+ gt_masks (list[Tensor]): Ground truth masks for each image with
+ the same shape of the input image.
+ gt_labels (list[Tensor]): Class indices corresponding to each box.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ loss_segm = []
+ num_imgs, num_classes, mask_h, mask_w = segm_pred.size()
+ for idx in range(num_imgs):
+ cur_segm_pred = segm_pred[idx]
+ cur_gt_masks = gt_masks[idx].float()
+ cur_gt_labels = gt_labels[idx]
+ segm_targets = self.get_targets(cur_segm_pred, cur_gt_masks,
+ cur_gt_labels)
+ if segm_targets is None:
+ loss = self.loss_segm(cur_segm_pred,
+ torch.zeros_like(cur_segm_pred),
+ torch.zeros_like(cur_segm_pred))
+ else:
+ loss = self.loss_segm(
+ cur_segm_pred,
+ segm_targets,
+ avg_factor=num_imgs * mask_h * mask_w)
+ loss_segm.append(loss)
+ return dict(loss_segm=loss_segm)
+
+ def get_targets(self, segm_pred, gt_masks, gt_labels):
+ """Compute semantic segmentation targets for each image.
+
+ Args:
+ segm_pred (Tensor): Predicted semantic segmentation map
+ with shape (num_classes, H, W).
+ gt_masks (Tensor): Ground truth masks for each image with
+ the same shape of the input image.
+ gt_labels (Tensor): Class indices corresponding to each box.
+
+ Returns:
+ Tensor: Semantic segmentation targets with shape
+ (num_classes, H, W).
+ """
+ if gt_masks.size(0) == 0:
+ return None
+ num_classes, mask_h, mask_w = segm_pred.size()
+ with torch.no_grad():
+ downsampled_masks = F.interpolate(
+ gt_masks.unsqueeze(0), (mask_h, mask_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ downsampled_masks = downsampled_masks.gt(0.5).float()
+ segm_targets = torch.zeros_like(segm_pred, requires_grad=False)
+ for obj_idx in range(downsampled_masks.size(0)):
+ segm_targets[gt_labels[obj_idx] - 1] = torch.max(
+ segm_targets[gt_labels[obj_idx] - 1],
+ downsampled_masks[obj_idx])
+ return segm_targets
+
+
+@HEADS.register_module()
+class YOLACTProtonet(nn.Module):
+ """YOLACT mask head used in https://arxiv.org/abs/1904.02689.
+
+ This head outputs the mask prototypes for YOLACT.
+
+ Args:
+ in_channels (int): Number of channels in the input feature map.
+ proto_channels (tuple[int]): Output channels of protonet convs.
+ proto_kernel_sizes (tuple[int]): Kernel sizes of protonet convs.
+        include_last_relu (bool): Whether to keep the last ReLU of protonet.
+ num_protos (int): Number of prototypes.
+ num_classes (int): Number of categories excluding the background
+ category.
+ loss_mask_weight (float): Reweight the mask loss by this factor.
+ max_masks_to_train (int): Maximum number of masks to train for
+ each image.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels=256,
+ proto_channels=(256, 256, 256, None, 256, 32),
+ proto_kernel_sizes=(3, 3, 3, -2, 3, 1),
+ include_last_relu=True,
+ num_protos=32,
+ loss_mask_weight=1.0,
+ max_masks_to_train=100):
+ super(YOLACTProtonet, self).__init__()
+ self.in_channels = in_channels
+ self.proto_channels = proto_channels
+ self.proto_kernel_sizes = proto_kernel_sizes
+ self.include_last_relu = include_last_relu
+ self.protonet = self._init_layers()
+
+ self.loss_mask_weight = loss_mask_weight
+ self.num_protos = num_protos
+ self.num_classes = num_classes
+ self.max_masks_to_train = max_masks_to_train
+ self.fp16_enabled = False
+
+ def _init_layers(self):
+ """A helper function to take a config setting and turn it into a
+ network."""
+ # Possible patterns:
+ # ( 256, 3) -> conv
+ # ( 256,-2) -> deconv
+ # (None,-2) -> bilinear interpolate
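+        # With the default proto_channels=(256, 256, 256, None, 256, 32) and
+        # proto_kernel_sizes=(3, 3, 3, -2, 3, 1), this builds three 3x3 convs,
+        # a 2x bilinear upsample, another 3x3 conv and a final 1x1 conv that
+        # outputs num_protos=32 prototype channels, each followed by a ReLU.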
+ in_channels = self.in_channels
+ protonets = nn.ModuleList()
+ for num_channels, kernel_size in zip(self.proto_channels,
+ self.proto_kernel_sizes):
+ if kernel_size > 0:
+ layer = nn.Conv2d(
+ in_channels,
+ num_channels,
+ kernel_size,
+ padding=kernel_size // 2)
+ else:
+ if num_channels is None:
+ layer = InterpolateModule(
+ scale_factor=-kernel_size,
+ mode='bilinear',
+ align_corners=False)
+ else:
+ layer = nn.ConvTranspose2d(
+ in_channels,
+ num_channels,
+ -kernel_size,
+ padding=kernel_size // 2)
+ protonets.append(layer)
+ protonets.append(nn.ReLU(inplace=True))
+ in_channels = num_channels if num_channels is not None \
+ else in_channels
+ if not self.include_last_relu:
+ protonets = protonets[:-1]
+ return nn.Sequential(*protonets)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.protonet:
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, x, coeff_pred, bboxes, img_meta, sampling_results=None):
+ """Forward feature from the upstream network to get prototypes and
+ linearly combine the prototypes, using masks coefficients, into
+ instance masks. Finally, crop the instance masks with given bboxes.
+
+ Args:
+ x (Tensor): Feature from the upstream network, which is
+ a 4D-tensor.
+ coeff_pred (list[Tensor]): Mask coefficients for each scale
+ level with shape (N, num_anchors * num_protos, H, W).
+ bboxes (list[Tensor]): Box used for cropping with shape
+ (N, num_anchors * 4, H, W). During training, they are
+ ground truth boxes. During testing, they are predicted
+ boxes.
+ img_meta (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ sampling_results (List[:obj:``SamplingResult``]): Sampler results
+ for each image.
+
+ Returns:
+ list[Tensor]: Predicted instance segmentation masks.
+ """
+ prototypes = self.protonet(x)
+ prototypes = prototypes.permute(0, 2, 3, 1).contiguous()
+
+ num_imgs = x.size(0)
+ # Training state
+ if self.training:
+ coeff_pred_list = []
+ for coeff_pred_per_level in coeff_pred:
+ coeff_pred_per_level = \
+ coeff_pred_per_level.permute(0, 2, 3, 1)\
+ .reshape(num_imgs, -1, self.num_protos)
+ coeff_pred_list.append(coeff_pred_per_level)
+ coeff_pred = torch.cat(coeff_pred_list, dim=1)
+
+ mask_pred_list = []
+ for idx in range(num_imgs):
+ cur_prototypes = prototypes[idx]
+ cur_coeff_pred = coeff_pred[idx]
+ cur_bboxes = bboxes[idx]
+ cur_img_meta = img_meta[idx]
+
+ # Testing state
+ if not self.training:
+ bboxes_for_cropping = cur_bboxes
+ else:
+ cur_sampling_results = sampling_results[idx]
+ pos_assigned_gt_inds = \
+ cur_sampling_results.pos_assigned_gt_inds
+ bboxes_for_cropping = cur_bboxes[pos_assigned_gt_inds].clone()
+ pos_inds = cur_sampling_results.pos_inds
+ cur_coeff_pred = cur_coeff_pred[pos_inds]
+
+ # Linearly combine the prototypes with the mask coefficients
+ mask_pred = cur_prototypes @ cur_coeff_pred.t()
+ mask_pred = torch.sigmoid(mask_pred)
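+            # Shapes: cur_prototypes is (H, W, num_protos) and
+            # cur_coeff_pred.t() is (num_protos, num_boxes), so mask_pred holds
+            # one (H, W) sigmoid mask per box; it is permuted to
+            # (num_boxes, H, W) before being returned.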
+
+ h, w = cur_img_meta['img_shape'][:2]
+ bboxes_for_cropping[:, 0] /= w
+ bboxes_for_cropping[:, 1] /= h
+ bboxes_for_cropping[:, 2] /= w
+ bboxes_for_cropping[:, 3] /= h
+
+ mask_pred = self.crop(mask_pred, bboxes_for_cropping)
+ mask_pred = mask_pred.permute(2, 0, 1).contiguous()
+ mask_pred_list.append(mask_pred)
+ return mask_pred_list
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, gt_masks, gt_bboxes, img_meta, sampling_results):
+ """Compute loss of the head.
+
+ Args:
+ mask_pred (list[Tensor]): Predicted prototypes with shape
+ (num_classes, H, W).
+ gt_masks (list[Tensor]): Ground truth masks for each image with
+ the same shape of the input image.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ img_meta (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ sampling_results (List[:obj:``SamplingResult``]): Sampler results
+ for each image.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ loss_mask = []
+ num_imgs = len(mask_pred)
+ total_pos = 0
+ for idx in range(num_imgs):
+ cur_mask_pred = mask_pred[idx]
+ cur_gt_masks = gt_masks[idx].float()
+ cur_gt_bboxes = gt_bboxes[idx]
+ cur_img_meta = img_meta[idx]
+ cur_sampling_results = sampling_results[idx]
+
+ pos_assigned_gt_inds = cur_sampling_results.pos_assigned_gt_inds
+ num_pos = pos_assigned_gt_inds.size(0)
+ # Since we're producing (near) full image masks,
+ # it'd take too much vram to backprop on every single mask.
+ # Thus we select only a subset.
+ if num_pos > self.max_masks_to_train:
+ perm = torch.randperm(num_pos)
+ select = perm[:self.max_masks_to_train]
+ cur_mask_pred = cur_mask_pred[select]
+ pos_assigned_gt_inds = pos_assigned_gt_inds[select]
+ num_pos = self.max_masks_to_train
+ total_pos += num_pos
+
+ gt_bboxes_for_reweight = cur_gt_bboxes[pos_assigned_gt_inds]
+
+ mask_targets = self.get_targets(cur_mask_pred, cur_gt_masks,
+ pos_assigned_gt_inds)
+ if num_pos == 0:
+ loss = cur_mask_pred.sum() * 0.
+ elif mask_targets is None:
+ loss = F.binary_cross_entropy(cur_mask_pred,
+ torch.zeros_like(cur_mask_pred),
+ torch.zeros_like(cur_mask_pred))
+ else:
+ cur_mask_pred = torch.clamp(cur_mask_pred, 0, 1)
+ loss = F.binary_cross_entropy(
+ cur_mask_pred, mask_targets,
+ reduction='none') * self.loss_mask_weight
+
+ h, w = cur_img_meta['img_shape'][:2]
+ gt_bboxes_width = (gt_bboxes_for_reweight[:, 2] -
+ gt_bboxes_for_reweight[:, 0]) / w
+ gt_bboxes_height = (gt_bboxes_for_reweight[:, 3] -
+ gt_bboxes_for_reweight[:, 1]) / h
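+                # Dividing the spatially averaged BCE by the relative GT box
+                # width and height roughly normalizes the loss by box area, so
+                # small objects are not under-weighted by the full-image mean.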
+ loss = loss.mean(dim=(1,
+ 2)) / gt_bboxes_width / gt_bboxes_height
+ loss = torch.sum(loss)
+ loss_mask.append(loss)
+
+ if total_pos == 0:
+ total_pos += 1 # avoid nan
+ loss_mask = [x / total_pos for x in loss_mask]
+
+ return dict(loss_mask=loss_mask)
+
+ def get_targets(self, mask_pred, gt_masks, pos_assigned_gt_inds):
+ """Compute instance segmentation targets for each image.
+
+ Args:
+ mask_pred (Tensor): Predicted prototypes with shape
+ (num_classes, H, W).
+ gt_masks (Tensor): Ground truth masks for each image with
+ the same shape of the input image.
+ pos_assigned_gt_inds (Tensor): GT indices of the corresponding
+ positive samples.
+ Returns:
+ Tensor: Instance segmentation targets with shape
+ (num_instances, H, W).
+ """
+ if gt_masks.size(0) == 0:
+ return None
+ mask_h, mask_w = mask_pred.shape[-2:]
+ gt_masks = F.interpolate(
+ gt_masks.unsqueeze(0), (mask_h, mask_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0)
+ gt_masks = gt_masks.gt(0.5).float()
+ mask_targets = gt_masks[pos_assigned_gt_inds]
+ return mask_targets
+
+ def get_seg_masks(self, mask_pred, label_pred, img_meta, rescale):
+ """Resize, binarize, and format the instance mask predictions.
+
+ Args:
+ mask_pred (Tensor): shape (N, H, W).
+ label_pred (Tensor): shape (N, ).
+ img_meta (dict): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If rescale is False, then returned masks will
+ fit the scale of imgs[0].
+ Returns:
+ list[ndarray]: Mask predictions grouped by their predicted classes.
+ """
+ ori_shape = img_meta['ori_shape']
+ scale_factor = img_meta['scale_factor']
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ else:
+ img_h = np.round(ori_shape[0] * scale_factor[1]).astype(np.int32)
+ img_w = np.round(ori_shape[1] * scale_factor[0]).astype(np.int32)
+
+ cls_segms = [[] for _ in range(self.num_classes)]
+ if mask_pred.size(0) == 0:
+ return cls_segms
+
+ mask_pred = F.interpolate(
+ mask_pred.unsqueeze(0), (img_h, img_w),
+ mode='bilinear',
+ align_corners=False).squeeze(0) > 0.5
+ mask_pred = mask_pred.cpu().numpy().astype(np.uint8)
+
+ for m, l in zip(mask_pred, label_pred):
+ cls_segms[l].append(m)
+ return cls_segms
+
+ def crop(self, masks, boxes, padding=1):
+ """Crop predicted masks by zeroing out everything not in the predicted
+ bbox.
+
+ Args:
+ masks (Tensor): shape [H, W, N].
+ boxes (Tensor): bbox coords in relative point form with
+ shape [N, 4].
+
+ Return:
+ Tensor: The cropped masks.
+ """
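+        # Build per-pixel row/column index grids and zero out every pixel that
+        # falls outside its box's [x1, x2) x [y1, y2) range; the relative box
+        # coordinates are converted to absolute ones (and expanded by
+        # ``padding`` pixels) by sanitize_coordinates().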
+ h, w, n = masks.size()
+ x1, x2 = self.sanitize_coordinates(
+ boxes[:, 0], boxes[:, 2], w, padding, cast=False)
+ y1, y2 = self.sanitize_coordinates(
+ boxes[:, 1], boxes[:, 3], h, padding, cast=False)
+
+ rows = torch.arange(
+ w, device=masks.device, dtype=x1.dtype).view(1, -1,
+ 1).expand(h, w, n)
+ cols = torch.arange(
+ h, device=masks.device, dtype=x1.dtype).view(-1, 1,
+ 1).expand(h, w, n)
+
+ masks_left = rows >= x1.view(1, 1, -1)
+ masks_right = rows < x2.view(1, 1, -1)
+ masks_up = cols >= y1.view(1, 1, -1)
+ masks_down = cols < y2.view(1, 1, -1)
+
+ crop_mask = masks_left * masks_right * masks_up * masks_down
+
+ return masks * crop_mask.float()
+
+ def sanitize_coordinates(self, x1, x2, img_size, padding=0, cast=True):
+ """Sanitizes the input coordinates so that x1 < x2, x1 != x2, x1 >= 0,
+ and x2 <= image_size. Also converts from relative to absolute
+ coordinates and casts the results to long tensors.
+
+ Warning: this does things in-place behind the scenes so
+ copy if necessary.
+
+ Args:
+            x1 (Tensor): shape (N, ).
+            x2 (Tensor): shape (N, ).
+ img_size (int): Size of the input image.
+ padding (int): x1 >= padding, x2 <= image_size-padding.
+ cast (bool): If cast is false, the result won't be cast to longs.
+
+ Returns:
+ tuple:
+                x1 (Tensor): Sanitized x1.
+                x2 (Tensor): Sanitized x2.
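+
+        Example (illustrative; relative inputs, ``img_size=100``)::
+
+            >>> x1, x2 = self.sanitize_coordinates(
+            ...     torch.tensor([0.7]), torch.tensor([0.3]), 100)
+            >>> x1.item(), x2.item()
+            (30, 70)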
+ """
+ x1 = x1 * img_size
+ x2 = x2 * img_size
+ if cast:
+ x1 = x1.long()
+ x2 = x2.long()
+        # compute min/max simultaneously so the first assignment does not
+        # clobber the value needed by the second
+        x1, x2 = torch.min(x1, x2), torch.max(x1, x2)
+ x1 = torch.clamp(x1 - padding, min=0)
+ x2 = torch.clamp(x2 + padding, max=img_size)
+ return x1, x2
+
+
+class InterpolateModule(nn.Module):
+ """This is a module version of F.interpolate.
+
+ Any arguments you give it just get passed along for the ride.
+ """
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+
+ self.args = args
+ self.kwargs = kwargs
+
+ def forward(self, x):
+ """Forward features from the upstream network."""
+ return F.interpolate(x, *self.args, **self.kwargs)
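+
+
+# Illustrative usage (not part of the original file): InterpolateModule lets a
+# fixed upsampling step live inside an nn.Sequential, e.g.
+#   up = InterpolateModule(scale_factor=2, mode='bilinear', align_corners=False)
+#   up(torch.rand(1, 8, 16, 16)).shape  # -> torch.Size([1, 8, 32, 32])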
diff --git a/mmdet/models/dense_heads/yolo_head.py b/mmdet/models/dense_heads/yolo_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..25a005d36903333f37a6c6d31b4d613c071f4a07
--- /dev/null
+++ b/mmdet/models/dense_heads/yolo_head.py
@@ -0,0 +1,577 @@
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import warnings
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, normal_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import (build_anchor_generator, build_assigner,
+ build_bbox_coder, build_sampler, images_to_levels,
+ multi_apply, multiclass_nms)
+from ..builder import HEADS, build_loss
+from .base_dense_head import BaseDenseHead
+from .dense_test_mixins import BBoxTestMixin
+
+
+@HEADS.register_module()
+class YOLOV3Head(BaseDenseHead, BBoxTestMixin):
+ """YOLOV3Head Paper link: https://arxiv.org/abs/1804.02767.
+
+ Args:
+ num_classes (int): The number of object classes (w/o background)
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (List[int]): The number of output channels per scale
+ before the final 1x1 layer. Default: (1024, 512, 256).
+ anchor_generator (dict): Config dict for anchor generator
+ bbox_coder (dict): Config of bounding box coder.
+ featmap_strides (List[int]): The stride of each scale.
+ Should be in descending order. Default: (32, 16, 8).
+        one_hot_smoother (float): Set a non-zero value to enable label
+            smoothing. Default: 0.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='LeakyReLU', negative_slope=0.1).
+ loss_cls (dict): Config of classification loss.
+ loss_conf (dict): Config of confidence loss.
+ loss_xy (dict): Config of xy coordinate loss.
+ loss_wh (dict): Config of wh coordinate loss.
+ train_cfg (dict): Training config of YOLOV3 head. Default: None.
+ test_cfg (dict): Testing config of YOLOV3 head. Default: None.
+ """
+
+ def __init__(self,
+ num_classes,
+ in_channels,
+ out_channels=(1024, 512, 256),
+ anchor_generator=dict(
+ type='YOLOAnchorGenerator',
+ base_sizes=[[(116, 90), (156, 198), (373, 326)],
+ [(30, 61), (62, 45), (59, 119)],
+ [(10, 13), (16, 30), (33, 23)]],
+ strides=[32, 16, 8]),
+ bbox_coder=dict(type='YOLOBBoxCoder'),
+ featmap_strides=[32, 16, 8],
+ one_hot_smoother=0.,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_conf=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_xy=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_wh=dict(type='MSELoss', loss_weight=1.0),
+ train_cfg=None,
+ test_cfg=None):
+ super(YOLOV3Head, self).__init__()
+ # Check params
+ assert (len(in_channels) == len(out_channels) == len(featmap_strides))
+
+ self.num_classes = num_classes
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.featmap_strides = featmap_strides
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ if self.train_cfg:
+ self.assigner = build_assigner(self.train_cfg.assigner)
+ if hasattr(self.train_cfg, 'sampler'):
+ sampler_cfg = self.train_cfg.sampler
+ else:
+ sampler_cfg = dict(type='PseudoSampler')
+ self.sampler = build_sampler(sampler_cfg, context=self)
+
+ self.one_hot_smoother = one_hot_smoother
+
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.anchor_generator = build_anchor_generator(anchor_generator)
+
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_conf = build_loss(loss_conf)
+ self.loss_xy = build_loss(loss_xy)
+ self.loss_wh = build_loss(loss_wh)
+ # usually the numbers of anchors for each level are the same
+ # except SSD detectors
+ self.num_anchors = self.anchor_generator.num_base_anchors[0]
+ assert len(
+ self.anchor_generator.num_base_anchors) == len(featmap_strides)
+ self._init_layers()
+
+ @property
+ def num_levels(self):
+ return len(self.featmap_strides)
+
+ @property
+ def num_attrib(self):
+ """int: number of attributes in pred_map, bboxes (4) +
+ objectness (1) + num_classes"""
+
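+        # e.g. 85 for COCO: 4 box offsets + 1 objectness + 80 class scores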
+ return 5 + self.num_classes
+
+ def _init_layers(self):
+ self.convs_bridge = nn.ModuleList()
+ self.convs_pred = nn.ModuleList()
+ for i in range(self.num_levels):
+ conv_bridge = ConvModule(
+ self.in_channels[i],
+ self.out_channels[i],
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ conv_pred = nn.Conv2d(self.out_channels[i],
+ self.num_anchors * self.num_attrib, 1)
+
+ self.convs_bridge.append(conv_bridge)
+ self.convs_pred.append(conv_pred)
+
+ def init_weights(self):
+ """Initialize weights of the head."""
+ for m in self.convs_pred:
+ normal_init(m, std=0.01)
+
+ def forward(self, feats):
+ """Forward features from the upstream network.
+
+ Args:
+ feats (tuple[Tensor]): Features from the upstream network, each is
+ a 4D-tensor.
+
+ Returns:
+            tuple[Tensor]: A tuple of multi-level prediction maps, each a
+                4D-tensor of shape
+                (batch_size, num_anchors * (5+num_classes), height, width).
+ """
+
+ assert len(feats) == self.num_levels
+ pred_maps = []
+ for i in range(self.num_levels):
+ x = feats[i]
+ x = self.convs_bridge[i](x)
+ pred_map = self.convs_pred[i](x)
+ pred_maps.append(pred_map)
+
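+        # the trailing comma wraps pred_maps in a single-element tuple so
+        # that callers can uniformly unpack head outputs with *outs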
+ return tuple(pred_maps),
+
+ @force_fp32(apply_to=('pred_maps', ))
+ def get_bboxes(self,
+ pred_maps,
+ img_metas,
+ cfg=None,
+ rescale=False,
+ with_nms=True):
+ """Transform network output for a batch into bbox predictions.
+
+ Args:
+ pred_maps (list[Tensor]): Raw predictions for a batch of images.
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used. Default: None.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a
+                2-tuple. The first item is an (n, 5) tensor, where the 5
+                columns are (tl_x, tl_y, br_x, br_y, score) and the score is
+                between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ num_levels = len(pred_maps)
+ pred_maps_list = [pred_maps[i].detach() for i in range(num_levels)]
+ scale_factors = [
+ img_metas[i]['scale_factor']
+ for i in range(pred_maps_list[0].shape[0])
+ ]
+ result_list = self._get_bboxes(pred_maps_list, scale_factors, cfg,
+ rescale, with_nms)
+ return result_list
+
+ def _get_bboxes(self,
+ pred_maps_list,
+ scale_factors,
+ cfg,
+ rescale=False,
+ with_nms=True):
+ """Transform outputs for a single batch item into bbox predictions.
+
+ Args:
+            pred_maps_list (list[Tensor]): Prediction maps for different
+                scales of each image in the batch.
+            scale_factors (list[ndarray]): Scale factor of each image,
+                arranged as (w_scale, h_scale, w_scale, h_scale).
+ cfg (mmcv.Config | None): Test / postprocessing configuration,
+ if None, test_cfg would be used.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ with_nms (bool): If True, do nms before return boxes.
+ Default: True.
+
+ Returns:
+            list[tuple[Tensor, Tensor]]: Each item in result_list is a
+                2-tuple. The first item is an (n, 5) tensor, where the 5
+                columns are (tl_x, tl_y, br_x, br_y, score) and the score is
+                between 0 and 1.
+ The shape of the second tensor in the tuple is (n,), and
+ each element represents the class label of the corresponding
+ box.
+ """
+ cfg = self.test_cfg if cfg is None else cfg
+ assert len(pred_maps_list) == self.num_levels
+
+ device = pred_maps_list[0].device
+ batch_size = pred_maps_list[0].shape[0]
+
+ featmap_sizes = [
+ pred_maps_list[i].shape[-2:] for i in range(self.num_levels)
+ ]
+ multi_lvl_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device)
+ # convert to tensor to keep tracing
+ nms_pre_tensor = torch.tensor(
+ cfg.get('nms_pre', -1), device=device, dtype=torch.long)
+
+ multi_lvl_bboxes = []
+ multi_lvl_cls_scores = []
+ multi_lvl_conf_scores = []
+ for i in range(self.num_levels):
+ # get some key info for current scale
+ pred_map = pred_maps_list[i]
+ stride = self.featmap_strides[i]
+ # (b,h, w, num_anchors*num_attrib) ->
+ # (b,h*w*num_anchors, num_attrib)
+ pred_map = pred_map.permute(0, 2, 3,
+ 1).reshape(batch_size, -1,
+ self.num_attrib)
+            # An inplace operation like
+            # ```pred_map[..., :2] = torch.sigmoid(pred_map[..., :2])```
+            # would create a constant tensor when exporting to ONNX
+ pred_map_conf = torch.sigmoid(pred_map[..., :2])
+ pred_map_rest = pred_map[..., 2:]
+ pred_map = torch.cat([pred_map_conf, pred_map_rest], dim=-1)
+ pred_map_boxes = pred_map[..., :4]
+ multi_lvl_anchor = multi_lvl_anchors[i]
+ multi_lvl_anchor = multi_lvl_anchor.expand_as(pred_map_boxes)
+ bbox_pred = self.bbox_coder.decode(multi_lvl_anchor,
+ pred_map_boxes, stride)
+ # conf and cls
+ conf_pred = torch.sigmoid(pred_map[..., 4])
+ cls_pred = torch.sigmoid(pred_map[..., 5:]).view(
+ batch_size, -1, self.num_classes) # Cls pred one-hot.
+
+ # Get top-k prediction
+ # Always keep topk op for dynamic input in onnx
+ if nms_pre_tensor > 0 and (torch.onnx.is_in_onnx_export()
+ or conf_pred.shape[1] > nms_pre_tensor):
+ from torch import _shape_as_tensor
+ # keep shape as tensor and get k
+ num_anchor = _shape_as_tensor(conf_pred)[1].to(device)
+ nms_pre = torch.where(nms_pre_tensor < num_anchor,
+ nms_pre_tensor, num_anchor)
+ _, topk_inds = conf_pred.topk(nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ bbox_pred = bbox_pred[batch_inds, topk_inds, :]
+ cls_pred = cls_pred[batch_inds, topk_inds, :]
+ conf_pred = conf_pred[batch_inds, topk_inds]
+
+ # Save the result of current scale
+ multi_lvl_bboxes.append(bbox_pred)
+ multi_lvl_cls_scores.append(cls_pred)
+ multi_lvl_conf_scores.append(conf_pred)
+
+ # Merge the results of different scales together
+ batch_mlvl_bboxes = torch.cat(multi_lvl_bboxes, dim=1)
+ batch_mlvl_scores = torch.cat(multi_lvl_cls_scores, dim=1)
+ batch_mlvl_conf_scores = torch.cat(multi_lvl_conf_scores, dim=1)
+
+ # Set max number of box to be feed into nms in deployment
+ deploy_nms_pre = cfg.get('deploy_nms_pre', -1)
+ if deploy_nms_pre > 0 and torch.onnx.is_in_onnx_export():
+ _, topk_inds = batch_mlvl_conf_scores.topk(deploy_nms_pre)
+ batch_inds = torch.arange(batch_size).view(
+ -1, 1).expand_as(topk_inds).long()
+ batch_mlvl_bboxes = batch_mlvl_bboxes[batch_inds, topk_inds, :]
+ batch_mlvl_scores = batch_mlvl_scores[batch_inds, topk_inds, :]
+ batch_mlvl_conf_scores = batch_mlvl_conf_scores[batch_inds,
+ topk_inds]
+
+ if with_nms and (batch_mlvl_conf_scores.size(0) == 0):
+ return torch.zeros((0, 5)), torch.zeros((0, ))
+
+ if rescale:
+ batch_mlvl_bboxes /= batch_mlvl_bboxes.new_tensor(
+ scale_factors).unsqueeze(1)
+
+ # In mmdet 2.x, the class_id for background is num_classes.
+ # i.e., the last column.
+ padding = batch_mlvl_scores.new_zeros(batch_size,
+ batch_mlvl_scores.shape[1], 1)
+ batch_mlvl_scores = torch.cat([batch_mlvl_scores, padding], dim=-1)
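+        # scores now have shape (batch, num_boxes, num_classes + 1), with an
+        # all-zero background column appended last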
+
+ # Support exporting to onnx without nms
+ if with_nms and cfg.get('nms', None) is not None:
+ det_results = []
+ for (mlvl_bboxes, mlvl_scores,
+ mlvl_conf_scores) in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+ batch_mlvl_conf_scores):
+ # Filtering out all predictions with conf < conf_thr
+ conf_thr = cfg.get('conf_thr', -1)
+ if conf_thr > 0 and (not torch.onnx.is_in_onnx_export()):
+                    # TensorRT does not support NonZero
+ # add as_tuple=False for compatibility in Pytorch 1.6
+ # flatten would create a Reshape op with constant values,
+ # and raise RuntimeError when doing inference in ONNX
+ # Runtime with a different input image (#4221).
+ conf_inds = mlvl_conf_scores.ge(conf_thr).nonzero(
+ as_tuple=False).squeeze(1)
+ mlvl_bboxes = mlvl_bboxes[conf_inds, :]
+ mlvl_scores = mlvl_scores[conf_inds, :]
+ mlvl_conf_scores = mlvl_conf_scores[conf_inds]
+
+ det_bboxes, det_labels = multiclass_nms(
+ mlvl_bboxes,
+ mlvl_scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=mlvl_conf_scores)
+ det_results.append(tuple([det_bboxes, det_labels]))
+
+ else:
+ det_results = [
+ tuple(mlvl_bs)
+ for mlvl_bs in zip(batch_mlvl_bboxes, batch_mlvl_scores,
+ batch_mlvl_conf_scores)
+ ]
+ return det_results
+
+ @force_fp32(apply_to=('pred_maps', ))
+ def loss(self,
+ pred_maps,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ gt_bboxes_ignore=None):
+ """Compute loss of the head.
+
+ Args:
+ pred_maps (list[Tensor]): Prediction map for each scale level,
+ shape (N, num_anchors * num_attrib, H, W)
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ img_metas (list[dict]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ num_imgs = len(img_metas)
+ device = pred_maps[0][0].device
+
+ featmap_sizes = [
+ pred_maps[i].shape[-2:] for i in range(self.num_levels)
+ ]
+ multi_level_anchors = self.anchor_generator.grid_anchors(
+ featmap_sizes, device)
+ anchor_list = [multi_level_anchors for _ in range(num_imgs)]
+
+ responsible_flag_list = []
+ for img_id in range(len(img_metas)):
+ responsible_flag_list.append(
+ self.anchor_generator.responsible_flags(
+ featmap_sizes, gt_bboxes[img_id], device))
+
+ target_maps_list, neg_maps_list = self.get_targets(
+ anchor_list, responsible_flag_list, gt_bboxes, gt_labels)
+
+ losses_cls, losses_conf, losses_xy, losses_wh = multi_apply(
+ self.loss_single, pred_maps, target_maps_list, neg_maps_list)
+
+ return dict(
+ loss_cls=losses_cls,
+ loss_conf=losses_conf,
+ loss_xy=losses_xy,
+ loss_wh=losses_wh)
+
+ def loss_single(self, pred_map, target_map, neg_map):
+ """Compute loss of a single image from a batch.
+
+ Args:
+ pred_map (Tensor): Raw predictions for a single level.
+ target_map (Tensor): The Ground-Truth target for a single level.
+ neg_map (Tensor): The negative masks for a single level.
+
+ Returns:
+ tuple:
+ loss_cls (Tensor): Classification loss.
+ loss_conf (Tensor): Confidence loss.
+ loss_xy (Tensor): Regression loss of x, y coordinate.
+ loss_wh (Tensor): Regression loss of w, h coordinate.
+ """
+
+ num_imgs = len(pred_map)
+ pred_map = pred_map.permute(0, 2, 3,
+ 1).reshape(num_imgs, -1, self.num_attrib)
+ neg_mask = neg_map.float()
+ pos_mask = target_map[..., 4]
+ pos_and_neg_mask = neg_mask + pos_mask
+ pos_mask = pos_mask.unsqueeze(dim=-1)
+ if torch.max(pos_and_neg_mask) > 1.:
+            warnings.warn('There is overlap between pos and neg samples.')
+ pos_and_neg_mask = pos_and_neg_mask.clamp(min=0., max=1.)
+
+ pred_xy = pred_map[..., :2]
+ pred_wh = pred_map[..., 2:4]
+ pred_conf = pred_map[..., 4]
+ pred_label = pred_map[..., 5:]
+
+ target_xy = target_map[..., :2]
+ target_wh = target_map[..., 2:4]
+ target_conf = target_map[..., 4]
+ target_label = target_map[..., 5:]
+
+ loss_cls = self.loss_cls(pred_label, target_label, weight=pos_mask)
+ loss_conf = self.loss_conf(
+ pred_conf, target_conf, weight=pos_and_neg_mask)
+ loss_xy = self.loss_xy(pred_xy, target_xy, weight=pos_mask)
+ loss_wh = self.loss_wh(pred_wh, target_wh, weight=pos_mask)
+
+ return loss_cls, loss_conf, loss_xy, loss_wh
+
+ def get_targets(self, anchor_list, responsible_flag_list, gt_bboxes_list,
+ gt_labels_list):
+ """Compute target maps for anchors in multiple images.
+
+ Args:
+ anchor_list (list[list[Tensor]]): Multi level anchors of each
+ image. The outer list indicates images, and the inner list
+ corresponds to feature levels of the image. Each element of
+ the inner list is a tensor of shape (num_total_anchors, 4).
+ responsible_flag_list (list[list[Tensor]]): Multi level responsible
+ flags of each image. Each element is a tensor of shape
+ (num_total_anchors, )
+ gt_bboxes_list (list[Tensor]): Ground truth bboxes of each image.
+ gt_labels_list (list[Tensor]): Ground truth labels of each box.
+
+ Returns:
+ tuple: Usually returns a tuple containing learning targets.
+ - target_map_list (list[Tensor]): Target map of each level.
+ - neg_map_list (list[Tensor]): Negative map of each level.
+ """
+ num_imgs = len(anchor_list)
+
+ # anchor number of multi levels
+ num_level_anchors = [anchors.size(0) for anchors in anchor_list[0]]
+
+ results = multi_apply(self._get_targets_single, anchor_list,
+ responsible_flag_list, gt_bboxes_list,
+ gt_labels_list)
+
+ all_target_maps, all_neg_maps = results
+ assert num_imgs == len(all_target_maps) == len(all_neg_maps)
+ target_maps_list = images_to_levels(all_target_maps, num_level_anchors)
+ neg_maps_list = images_to_levels(all_neg_maps, num_level_anchors)
+
+ return target_maps_list, neg_maps_list
+
+ def _get_targets_single(self, anchors, responsible_flags, gt_bboxes,
+ gt_labels):
+ """Generate matching bounding box prior and converted GT.
+
+ Args:
+ anchors (list[Tensor]): Multi-level anchors of the image.
+ responsible_flags (list[Tensor]): Multi-level responsible flags of
+ anchors
+ gt_bboxes (Tensor): Ground truth bboxes of single image.
+ gt_labels (Tensor): Ground truth labels of single image.
+
+ Returns:
+ tuple:
+                target_map (Tensor): Prediction target map of each
+ scale level, shape (num_total_anchors,
+ 5+num_classes)
+ neg_map (Tensor): Negative map of each scale level,
+ shape (num_total_anchors,)
+ """
+
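+        # tag each anchor with the stride of its feature level so the bbox
+        # coder can normalize offsets consistently per level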
+ anchor_strides = []
+ for i in range(len(anchors)):
+ anchor_strides.append(
+ torch.tensor(self.featmap_strides[i],
+ device=gt_bboxes.device).repeat(len(anchors[i])))
+ concat_anchors = torch.cat(anchors)
+ concat_responsible_flags = torch.cat(responsible_flags)
+
+ anchor_strides = torch.cat(anchor_strides)
+ assert len(anchor_strides) == len(concat_anchors) == \
+ len(concat_responsible_flags)
+ assign_result = self.assigner.assign(concat_anchors,
+ concat_responsible_flags,
+ gt_bboxes)
+ sampling_result = self.sampler.sample(assign_result, concat_anchors,
+ gt_bboxes)
+
+ target_map = concat_anchors.new_zeros(
+ concat_anchors.size(0), self.num_attrib)
+
+ target_map[sampling_result.pos_inds, :4] = self.bbox_coder.encode(
+ sampling_result.pos_bboxes, sampling_result.pos_gt_bboxes,
+ anchor_strides[sampling_result.pos_inds])
+
+ target_map[sampling_result.pos_inds, 4] = 1
+
+ gt_labels_one_hot = F.one_hot(
+ gt_labels, num_classes=self.num_classes).float()
+ if self.one_hot_smoother != 0: # label smooth
+ gt_labels_one_hot = gt_labels_one_hot * (
+ 1 - self.one_hot_smoother
+ ) + self.one_hot_smoother / self.num_classes
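+            # e.g. with num_classes=4 and one_hot_smoother=0.1, a one-hot
+            # target [0, 1, 0, 0] becomes [0.025, 0.925, 0.025, 0.025]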
+ target_map[sampling_result.pos_inds, 5:] = gt_labels_one_hot[
+ sampling_result.pos_assigned_gt_inds]
+
+ neg_map = concat_anchors.new_zeros(
+ concat_anchors.size(0), dtype=torch.uint8)
+ neg_map[sampling_result.neg_inds] = 1
+
+ return target_map, neg_map
+
+ def aug_test(self, feats, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ feats (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains features for all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[ndarray]: bbox results of each class
+ """
+ return self.aug_test_bboxes(feats, img_metas, rescale=rescale)
diff --git a/mmdet/models/detectors/__init__.py b/mmdet/models/detectors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..04011130435cf9fdfadeb821919046b1bddab7d4
--- /dev/null
+++ b/mmdet/models/detectors/__init__.py
@@ -0,0 +1,40 @@
+from .atss import ATSS
+from .base import BaseDetector
+from .cascade_rcnn import CascadeRCNN
+from .cornernet import CornerNet
+from .detr import DETR
+from .fast_rcnn import FastRCNN
+from .faster_rcnn import FasterRCNN
+from .fcos import FCOS
+from .fovea import FOVEA
+from .fsaf import FSAF
+from .gfl import GFL
+from .grid_rcnn import GridRCNN
+from .htc import HybridTaskCascade
+from .kd_one_stage import KnowledgeDistillationSingleStageDetector
+from .mask_rcnn import MaskRCNN
+from .mask_scoring_rcnn import MaskScoringRCNN
+from .nasfcos import NASFCOS
+from .paa import PAA
+from .point_rend import PointRend
+from .reppoints_detector import RepPointsDetector
+from .retinanet import RetinaNet
+from .rpn import RPN
+from .scnet import SCNet
+from .single_stage import SingleStageDetector
+from .sparse_rcnn import SparseRCNN
+from .trident_faster_rcnn import TridentFasterRCNN
+from .two_stage import TwoStageDetector
+from .vfnet import VFNet
+from .yolact import YOLACT
+from .yolo import YOLOV3
+
+__all__ = [
+ 'ATSS', 'BaseDetector', 'SingleStageDetector',
+ 'KnowledgeDistillationSingleStageDetector', 'TwoStageDetector', 'RPN',
+ 'FastRCNN', 'FasterRCNN', 'MaskRCNN', 'CascadeRCNN', 'HybridTaskCascade',
+ 'RetinaNet', 'FCOS', 'GridRCNN', 'MaskScoringRCNN', 'RepPointsDetector',
+ 'FOVEA', 'FSAF', 'NASFCOS', 'PointRend', 'GFL', 'CornerNet', 'PAA',
+ 'YOLOV3', 'YOLACT', 'VFNet', 'DETR', 'TridentFasterRCNN', 'SparseRCNN',
+ 'SCNet'
+]
diff --git a/mmdet/models/detectors/atss.py b/mmdet/models/detectors/atss.py
new file mode 100644
index 0000000000000000000000000000000000000000..db7139c6b4fcd7e83007cdb785520743ddae7066
--- /dev/null
+++ b/mmdet/models/detectors/atss.py
@@ -0,0 +1,17 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class ATSS(SingleStageDetector):
+ """Implementation of `ATSS `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(ATSS, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/base.py b/mmdet/models/detectors/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..89134f3696ead442a5ff57184e9d256fdf7d0ba4
--- /dev/null
+++ b/mmdet/models/detectors/base.py
@@ -0,0 +1,355 @@
+from abc import ABCMeta, abstractmethod
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.nn as nn
+from mmcv.runner import auto_fp16
+from mmcv.utils import print_log
+
+from mmdet.core.visualization import imshow_det_bboxes
+from mmdet.utils import get_root_logger
+
+
+class BaseDetector(nn.Module, metaclass=ABCMeta):
+ """Base class for detectors."""
+
+ def __init__(self):
+ super(BaseDetector, self).__init__()
+ self.fp16_enabled = False
+
+ @property
+ def with_neck(self):
+ """bool: whether the detector has a neck"""
+ return hasattr(self, 'neck') and self.neck is not None
+
+ # TODO: these properties need to be carefully handled
+ # for both single stage & two stage detectors
+ @property
+ def with_shared_head(self):
+ """bool: whether the detector has a shared head in the RoI Head"""
+ return hasattr(self, 'roi_head') and self.roi_head.with_shared_head
+
+ @property
+ def with_bbox(self):
+ """bool: whether the detector has a bbox head"""
+ return ((hasattr(self, 'roi_head') and self.roi_head.with_bbox)
+ or (hasattr(self, 'bbox_head') and self.bbox_head is not None))
+
+ @property
+ def with_mask(self):
+ """bool: whether the detector has a mask head"""
+ return ((hasattr(self, 'roi_head') and self.roi_head.with_mask)
+ or (hasattr(self, 'mask_head') and self.mask_head is not None))
+
+ @abstractmethod
+ def extract_feat(self, imgs):
+ """Extract features from images."""
+ pass
+
+ def extract_feats(self, imgs):
+ """Extract features from multiple images.
+
+ Args:
+ imgs (list[torch.Tensor]): A list of images. The images are
+ augmented from the same image but in different ways.
+
+ Returns:
+ list[torch.Tensor]: Features of different images
+ """
+ assert isinstance(imgs, list)
+ return [self.extract_feat(img) for img in imgs]
+
+ def forward_train(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+            imgs (list[Tensor]): List of tensors of shape (1, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys, see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ kwargs (keyword arguments): Specific to concrete implementation.
+ """
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ batch_input_shape = tuple(imgs[0].size()[-2:])
+ for img_meta in img_metas:
+ img_meta['batch_input_shape'] = batch_input_shape
+
+ async def async_simple_test(self, img, img_metas, **kwargs):
+ raise NotImplementedError
+
+ @abstractmethod
+ def simple_test(self, img, img_metas, **kwargs):
+ pass
+
+ @abstractmethod
+ def aug_test(self, imgs, img_metas, **kwargs):
+ """Test function with test time augmentation."""
+ pass
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in detector.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if pretrained is not None:
+ logger = get_root_logger()
+ print_log(f'load model from: {pretrained}', logger=logger)
+
+ async def aforward_test(self, *, img, img_metas, **kwargs):
+ for var, name in [(img, 'img'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(img)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(img)}) '
+ f'!= num of image metas ({len(img_metas)})')
+ # TODO: remove the restriction of samples_per_gpu == 1 when prepared
+ samples_per_gpu = img[0].size(0)
+ assert samples_per_gpu == 1
+
+ if num_augs == 1:
+ return await self.async_simple_test(img[0], img_metas[0], **kwargs)
+ else:
+ raise NotImplementedError
+
+ def forward_test(self, imgs, img_metas, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) '
+ f'!= num of image meta ({len(img_metas)})')
+
+ # NOTE the batched image size information may be useful, e.g.
+ # in DETR, this is needed for the construction of masks, which is
+ # then used for the transformer_head.
+ for img, img_meta in zip(imgs, img_metas):
+ batch_size = len(img_meta)
+ for img_id in range(batch_size):
+ img_meta[img_id]['batch_input_shape'] = tuple(img.size()[-2:])
+
+ if num_augs == 1:
+ # proposals (List[List[Tensor]]): the outer list indicates
+ # test-time augs (multiscale, flip, etc.) and the inner list
+ # indicates images in a batch.
+ # The Tensor should have a shape Px4, where P is the number of
+ # proposals.
+ if 'proposals' in kwargs:
+ kwargs['proposals'] = kwargs['proposals'][0]
+ return self.simple_test(imgs[0], img_metas[0], **kwargs)
+ else:
+ assert imgs[0].size(0) == 1, 'aug test does not support ' \
+ 'inference with batch size ' \
+ f'{imgs[0].size(0)}'
+ # TODO: support test augmentation for predefined proposals
+ assert 'proposals' not in kwargs
+ return self.aug_test(imgs, img_metas, **kwargs)
+
+ @auto_fp16(apply_to=('img', ))
+ def forward(self, img, img_metas, return_loss=True, **kwargs):
+ """Calls either :func:`forward_train` or :func:`forward_test` depending
+ on whether ``return_loss`` is ``True``.
+
+ Note this setting will change the expected inputs. When
+ ``return_loss=True``, img and img_meta are single-nested (i.e. Tensor
+        and List[dict]), and when ``return_loss=False``, img and img_meta
+ should be double nested (i.e. List[Tensor], List[List[dict]]), with
+ the outer list indicating test time augmentations.
+ """
+ if return_loss:
+ return self.forward_train(img, img_metas, **kwargs)
+ else:
+ return self.forward_test(img, img_metas, **kwargs)
+
+ def _parse_losses(self, losses):
+ """Parse the raw outputs (losses) of the network.
+
+ Args:
+            losses (dict): Raw output of the network, which usually contains
+                losses and other necessary information.
+
+ Returns:
+ tuple[Tensor, dict]: (loss, log_vars), loss is the loss tensor \
+ which may be a weighted sum of all losses, log_vars contains \
+ all the variables to be sent to the logger.
+ """
+ log_vars = OrderedDict()
+ for loss_name, loss_value in losses.items():
+ if isinstance(loss_value, torch.Tensor):
+ log_vars[loss_name] = loss_value.mean()
+ elif isinstance(loss_value, list):
+ log_vars[loss_name] = sum(_loss.mean() for _loss in loss_value)
+ else:
+ raise TypeError(
+ f'{loss_name} is not a tensor or list of tensors')
+
+ loss = sum(_value for _key, _value in log_vars.items()
+ if 'loss' in _key)
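+        # only entries whose key contains 'loss' are summed into the
+        # objective; other metrics (e.g. an accuracy) are logged but not
+        # back-propagated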
+
+ log_vars['loss'] = loss
+ for loss_name, loss_value in log_vars.items():
+ # reduce loss when distributed training
+ if dist.is_available() and dist.is_initialized():
+ loss_value = loss_value.data.clone()
+ dist.all_reduce(loss_value.div_(dist.get_world_size()))
+ log_vars[loss_name] = loss_value.item()
+
+ return loss, log_vars
+
+ def train_step(self, data, optimizer):
+ """The iteration step during training.
+
+ This method defines an iteration step during training, except for the
+ back propagation and optimizer updating, which are done in an optimizer
+ hook. Note that in some complicated cases or models, the whole process
+ including back propagation and optimizer updating is also defined in
+ this method, such as GAN.
+
+ Args:
+ data (dict): The output of dataloader.
+ optimizer (:obj:`torch.optim.Optimizer` | dict): The optimizer of
+ runner is passed to ``train_step()``. This argument is unused
+ and reserved.
+
+ Returns:
+ dict: It should contain at least 3 keys: ``loss``, ``log_vars``, \
+ ``num_samples``.
+
+ - ``loss`` is a tensor for back propagation, which can be a \
+ weighted sum of multiple losses.
+ - ``log_vars`` contains all the variables to be sent to the
+ logger.
+ - ``num_samples`` indicates the batch size (when the model is \
+ DDP, it means the batch size on each GPU), which is used for \
+ averaging the logs.
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def val_step(self, data, optimizer):
+ """The iteration step during validation.
+
+ This method shares the same signature as :func:`train_step`, but used
+ during val epochs. Note that the evaluation after training epochs is
+ not implemented with this method, but an evaluation hook.
+ """
+ losses = self(**data)
+ loss, log_vars = self._parse_losses(losses)
+
+ outputs = dict(
+ loss=loss, log_vars=log_vars, num_samples=len(data['img_metas']))
+
+ return outputs
+
+ def show_result(self,
+ img,
+ result,
+ score_thr=0.3,
+ bbox_color=(72, 101, 241),
+ text_color=(72, 101, 241),
+ mask_color=None,
+ thickness=2,
+ font_size=13,
+ win_name='',
+ show=False,
+ wait_time=0,
+ out_file=None):
+ """Draw `result` over `img`.
+
+ Args:
+ img (str or Tensor): The image to be displayed.
+ result (Tensor or tuple): The results to draw over `img`
+ bbox_result or (bbox_result, segm_result).
+ score_thr (float, optional): Minimum score of bboxes to be shown.
+ Default: 0.3.
+            bbox_color (str or tuple(int) or :obj:`Color`): Color of bbox
+                lines. The tuple of color should be in BGR order.
+                Default: (72, 101, 241)
+            text_color (str or tuple(int) or :obj:`Color`): Color of texts.
+                The tuple of color should be in BGR order.
+                Default: (72, 101, 241)
+            mask_color (None or str or tuple(int) or :obj:`Color`):
+                Color of masks. The tuple of color should be in BGR order.
+                Default: None
+ thickness (int): Thickness of lines. Default: 2
+ font_size (int): Font size of texts. Default: 13
+ win_name (str): The window name. Default: ''
+ wait_time (float): Value of waitKey param.
+ Default: 0.
+ show (bool): Whether to show the image.
+ Default: False.
+ out_file (str or None): The filename to write the image.
+ Default: None.
+
+ Returns:
+            img (np.ndarray): The drawn image, returned only when neither
+                `show` nor `out_file` is specified.
+ """
+ img = mmcv.imread(img)
+ img = img.copy()
+ if isinstance(result, tuple):
+ bbox_result, segm_result = result
+ if isinstance(segm_result, tuple):
+ segm_result = segm_result[0] # ms rcnn
+ else:
+ bbox_result, segm_result = result, None
+ bboxes = np.vstack(bbox_result)
+ labels = [
+ np.full(bbox.shape[0], i, dtype=np.int32)
+ for i, bbox in enumerate(bbox_result)
+ ]
+ labels = np.concatenate(labels)
+ # draw segmentation masks
+ segms = None
+ if segm_result is not None and len(labels) > 0: # non empty
+ segms = mmcv.concat_list(segm_result)
+ if isinstance(segms[0], torch.Tensor):
+ segms = torch.stack(segms, dim=0).detach().cpu().numpy()
+ else:
+ segms = np.stack(segms, axis=0)
+ # if out_file specified, do not show image in window
+ if out_file is not None:
+ show = False
+ # draw bounding boxes
+ img = imshow_det_bboxes(
+ img,
+ bboxes,
+ labels,
+ segms,
+ class_names=self.CLASSES,
+ score_thr=score_thr,
+ bbox_color=bbox_color,
+ text_color=text_color,
+ mask_color=mask_color,
+ thickness=thickness,
+ font_size=font_size,
+ win_name=win_name,
+ show=show,
+ wait_time=wait_time,
+ out_file=out_file)
+
+ if not (show or out_file):
+ return img
diff --git a/mmdet/models/detectors/cascade_rcnn.py b/mmdet/models/detectors/cascade_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d873dceb7e4efdf8d1e7d282badfe9b7118426b9
--- /dev/null
+++ b/mmdet/models/detectors/cascade_rcnn.py
@@ -0,0 +1,46 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class CascadeRCNN(TwoStageDetector):
+ r"""Implementation of `Cascade R-CNN: Delving into High Quality Object
+ Detection `_"""
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ rpn_head=None,
+ roi_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(CascadeRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+
+ def show_result(self, data, result, **kwargs):
+ """Show prediction results of the detector.
+
+ Args:
+ data (str or np.ndarray): Image filename or loaded image.
+ result (Tensor or tuple): The results to draw over `img`
+ bbox_result or (bbox_result, segm_result).
+
+ Returns:
+ np.ndarray: The image with bboxes drawn on it.
+ """
+ if self.with_mask:
+ ms_bbox_result, ms_segm_result = result
+ if isinstance(ms_bbox_result, dict):
+ result = (ms_bbox_result['ensemble'],
+ ms_segm_result['ensemble'])
+ else:
+ if isinstance(result, dict):
+ result = result['ensemble']
+ return super(CascadeRCNN, self).show_result(data, result, **kwargs)
diff --git a/mmdet/models/detectors/cornernet.py b/mmdet/models/detectors/cornernet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb8ccc1465ab66d1615ca16701a533a22b156295
--- /dev/null
+++ b/mmdet/models/detectors/cornernet.py
@@ -0,0 +1,95 @@
+import torch
+
+from mmdet.core import bbox2result, bbox_mapping_back
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class CornerNet(SingleStageDetector):
+ """CornerNet.
+
+ This detector is the implementation of the paper `CornerNet: Detecting
+    Objects as Paired Keypoints <https://arxiv.org/abs/1808.01244>`_ .
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(CornerNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
+
+ def merge_aug_results(self, aug_results, img_metas):
+ """Merge augmented detection bboxes and score.
+
+ Args:
+ aug_results (list[list[Tensor]]): Det_bboxes and det_labels of each
+ image.
+ img_metas (list[list[dict]]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+
+ Returns:
+ tuple: (bboxes, labels)
+ """
+ recovered_bboxes, aug_labels = [], []
+ for bboxes_labels, img_info in zip(aug_results, img_metas):
+ img_shape = img_info[0]['img_shape'] # using shape before padding
+ scale_factor = img_info[0]['scale_factor']
+ flip = img_info[0]['flip']
+ bboxes, labels = bboxes_labels
+ bboxes, scores = bboxes[:, :4], bboxes[:, -1:]
+ bboxes = bbox_mapping_back(bboxes, img_shape, scale_factor, flip)
+ recovered_bboxes.append(torch.cat([bboxes, scores], dim=-1))
+ aug_labels.append(labels)
+
+ bboxes = torch.cat(recovered_bboxes, dim=0)
+ labels = torch.cat(aug_labels)
+
+ if bboxes.shape[0] > 0:
+ out_bboxes, out_labels = self.bbox_head._bboxes_nms(
+ bboxes, labels, self.bbox_head.test_cfg)
+ else:
+ out_bboxes, out_labels = bboxes, labels
+
+ return out_bboxes, out_labels
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Augment testing of CornerNet.
+
+ Args:
+ imgs (list[Tensor]): Augmented images.
+ img_metas (list[list[dict]]): Meta information of each image, e.g.,
+ image size, scaling factor, etc.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Note:
+            ``imgs`` must include flipped image pairs.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ img_inds = list(range(len(imgs)))
+
+ assert img_metas[0][0]['flip'] + img_metas[1][0]['flip'], (
+ 'aug test must have flipped image pair')
+ aug_results = []
+ for ind, flip_ind in zip(img_inds[0::2], img_inds[1::2]):
+ img_pair = torch.cat([imgs[ind], imgs[flip_ind]])
+ x = self.extract_feat(img_pair)
+ outs = self.bbox_head(x)
+ bbox_list = self.bbox_head.get_bboxes(
+ *outs, [img_metas[ind], img_metas[flip_ind]], False, False)
+ aug_results.append(bbox_list[0])
+ aug_results.append(bbox_list[1])
+
+ bboxes, labels = self.merge_aug_results(aug_results, img_metas)
+ bbox_results = bbox2result(bboxes, labels, self.bbox_head.num_classes)
+
+ return [bbox_results]
diff --git a/mmdet/models/detectors/detr.py b/mmdet/models/detectors/detr.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ff82a280daa0a015f662bdf2509fa11542d46d4
--- /dev/null
+++ b/mmdet/models/detectors/detr.py
@@ -0,0 +1,46 @@
+from mmdet.core import bbox2result
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class DETR(SingleStageDetector):
+ r"""Implementation of `DETR: End-to-End Object Detection with
+ Transformers `_"""
+
+ def __init__(self,
+ backbone,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(DETR, self).__init__(backbone, None, bbox_head, train_cfg,
+ test_cfg, pretrained)
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation.
+
+ Args:
+ imgs (list[torch.Tensor]): List of multiple images
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ batch_size = len(img_metas)
+ assert batch_size == 1, 'Currently only batch_size 1 for inference ' \
+ f'mode is supported. Found batch_size {batch_size}.'
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x, img_metas)
+ bbox_list = self.bbox_head.get_bboxes(
+ *outs, img_metas, rescale=rescale)
+
+ bbox_results = [
+ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+ for det_bboxes, det_labels in bbox_list
+ ]
+ return bbox_results
diff --git a/mmdet/models/detectors/fast_rcnn.py b/mmdet/models/detectors/fast_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d6e242767b927ed37198b6bc7862abecef99a33
--- /dev/null
+++ b/mmdet/models/detectors/fast_rcnn.py
@@ -0,0 +1,52 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class FastRCNN(TwoStageDetector):
+ """Implementation of `Fast R-CNN `_"""
+
+ def __init__(self,
+ backbone,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+ super(FastRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+
+ def forward_test(self, imgs, img_metas, proposals, **kwargs):
+ """
+ Args:
+ imgs (List[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (List[List[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch.
+ proposals (List[List[Tensor]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. The Tensor should have a shape Px4, where
+ P is the number of proposals.
+ """
+ for var, name in [(imgs, 'imgs'), (img_metas, 'img_metas')]:
+ if not isinstance(var, list):
+ raise TypeError(f'{name} must be a list, but got {type(var)}')
+
+ num_augs = len(imgs)
+ if num_augs != len(img_metas):
+ raise ValueError(f'num of augmentations ({len(imgs)}) '
+ f'!= num of image meta ({len(img_metas)})')
+
+ if num_augs == 1:
+ return self.simple_test(imgs[0], img_metas[0], proposals[0],
+ **kwargs)
+ else:
+ # TODO: support test-time augmentation
+            raise NotImplementedError
diff --git a/mmdet/models/detectors/faster_rcnn.py b/mmdet/models/detectors/faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..81bad0f43a48b1022c4cd996e26d6c90be93d4d0
--- /dev/null
+++ b/mmdet/models/detectors/faster_rcnn.py
@@ -0,0 +1,24 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class FasterRCNN(TwoStageDetector):
+ """Implementation of `Faster R-CNN `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+ super(FasterRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/fcos.py b/mmdet/models/detectors/fcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..58485c1864a11a66168b7597f345ea759ce20551
--- /dev/null
+++ b/mmdet/models/detectors/fcos.py
@@ -0,0 +1,17 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class FCOS(SingleStageDetector):
+ """Implementation of `FCOS `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(FCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/fovea.py b/mmdet/models/detectors/fovea.py
new file mode 100644
index 0000000000000000000000000000000000000000..22a578efffbd108db644d907bae95c7c8df31f2e
--- /dev/null
+++ b/mmdet/models/detectors/fovea.py
@@ -0,0 +1,17 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class FOVEA(SingleStageDetector):
+ """Implementation of `FoveaBox `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(FOVEA, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/fsaf.py b/mmdet/models/detectors/fsaf.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f10fa1ae10f31e6cb5de65505b14a4fc97dd022
--- /dev/null
+++ b/mmdet/models/detectors/fsaf.py
@@ -0,0 +1,17 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class FSAF(SingleStageDetector):
+ """Implementation of `FSAF `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(FSAF, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/gfl.py b/mmdet/models/detectors/gfl.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d65cb2dfb7a56f57e08c3fcad67e1539e1e841
--- /dev/null
+++ b/mmdet/models/detectors/gfl.py
@@ -0,0 +1,16 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class GFL(SingleStageDetector):
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(GFL, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/grid_rcnn.py b/mmdet/models/detectors/grid_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6145a1464cd940bd4f98eaa15f6f9ecf6a10a20
--- /dev/null
+++ b/mmdet/models/detectors/grid_rcnn.py
@@ -0,0 +1,29 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class GridRCNN(TwoStageDetector):
+ """Grid R-CNN.
+
+ This detector is the implementation of:
+ - Grid R-CNN (https://arxiv.org/abs/1811.12030)
+ - Grid R-CNN Plus: Faster and Better (https://arxiv.org/abs/1906.05688)
+ """
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+ super(GridRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/htc.py b/mmdet/models/detectors/htc.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9efdf420fa7373f7f1d116f8d97836d73b457bf
--- /dev/null
+++ b/mmdet/models/detectors/htc.py
@@ -0,0 +1,15 @@
+from ..builder import DETECTORS
+from .cascade_rcnn import CascadeRCNN
+
+
+@DETECTORS.register_module()
+class HybridTaskCascade(CascadeRCNN):
+ """Implementation of `HTC `_"""
+
+ def __init__(self, **kwargs):
+ super(HybridTaskCascade, self).__init__(**kwargs)
+
+ @property
+ def with_semantic(self):
+ """bool: whether the detector has a semantic head"""
+ return self.roi_head.with_semantic
diff --git a/mmdet/models/detectors/kd_one_stage.py b/mmdet/models/detectors/kd_one_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..671ec19015c87fefd065b84ae887147f90cc892b
--- /dev/null
+++ b/mmdet/models/detectors/kd_one_stage.py
@@ -0,0 +1,100 @@
+import mmcv
+import torch
+from mmcv.runner import load_checkpoint
+
+from .. import build_detector
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class KnowledgeDistillationSingleStageDetector(SingleStageDetector):
+ r"""Implementation of `Distilling the Knowledge in a Neural Network.
+ `_.
+
+ Args:
+ teacher_config (str | dict): Config file path
+ or the config object of teacher model.
+ teacher_ckpt (str, optional): Checkpoint path of teacher model.
+ If left as None, the model will not load any weights.
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ teacher_config,
+ teacher_ckpt=None,
+ eval_teacher=True,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super().__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+ pretrained)
+ self.eval_teacher = eval_teacher
+ # Build teacher model
+ if isinstance(teacher_config, str):
+ teacher_config = mmcv.Config.fromfile(teacher_config)
+ self.teacher_model = build_detector(teacher_config['model'])
+ if teacher_ckpt is not None:
+ load_checkpoint(
+ self.teacher_model, teacher_ckpt, map_location='cpu')
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ gt_bboxes (list[Tensor]): Each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ x = self.extract_feat(img)
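+        # the teacher runs in no-grad mode; its head outputs supervise the
+        # student head together with the ground-truth targets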
+ with torch.no_grad():
+ teacher_x = self.teacher_model.extract_feat(img)
+ out_teacher = self.teacher_model.bbox_head(teacher_x)
+ losses = self.bbox_head.forward_train(x, out_teacher, img_metas,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore)
+ return losses
+
+ def cuda(self, device=None):
+ """Since teacher_model is registered as a plain object, it is necessary
+ to put the teacher model to cuda when calling cuda function."""
+ self.teacher_model.cuda(device=device)
+ return super().cuda(device=device)
+
+ def train(self, mode=True):
+ """Set the same train mode for teacher and student model."""
+ if self.eval_teacher:
+ self.teacher_model.train(False)
+ else:
+ self.teacher_model.train(mode)
+ super().train(mode)
+
+ def __setattr__(self, name, value):
+ """Set attribute, i.e. self.name = value
+
+        This override prevents the teacher model from being registered as an
+        nn.Module. The teacher module is registered as a plain object, so the
+        teacher parameters will not show up when calling the
+        ``self.parameters``, ``self.modules``, or ``self.children`` methods.
+ """
+ if name == 'teacher_model':
+ object.__setattr__(self, name, value)
+ else:
+ super().__setattr__(name, value)
diff --git a/mmdet/models/detectors/mask_rcnn.py b/mmdet/models/detectors/mask_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c15a7733170e059d2825138b3812319915b7cad6
--- /dev/null
+++ b/mmdet/models/detectors/mask_rcnn.py
@@ -0,0 +1,24 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class MaskRCNN(TwoStageDetector):
+ """Implementation of `Mask R-CNN `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+ super(MaskRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/mask_scoring_rcnn.py b/mmdet/models/detectors/mask_scoring_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6252b6e1d234a201725342a5780fade7e21957c
--- /dev/null
+++ b/mmdet/models/detectors/mask_scoring_rcnn.py
@@ -0,0 +1,27 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class MaskScoringRCNN(TwoStageDetector):
+ """Mask Scoring RCNN.
+
+ https://arxiv.org/abs/1903.00241
+ """
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+ super(MaskScoringRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/nasfcos.py b/mmdet/models/detectors/nasfcos.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb0148351546f45a451ef5f7a2a9ef4024e85b7c
--- /dev/null
+++ b/mmdet/models/detectors/nasfcos.py
@@ -0,0 +1,20 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class NASFCOS(SingleStageDetector):
+ """NAS-FCOS: Fast Neural Architecture Search for Object Detection.
+
+    https://arxiv.org/abs/1906.04423
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(NASFCOS, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/paa.py b/mmdet/models/detectors/paa.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b4bb5e0939b824d9fef7fc3bd49a0164c29613a
--- /dev/null
+++ b/mmdet/models/detectors/paa.py
@@ -0,0 +1,17 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class PAA(SingleStageDetector):
+ """Implementation of `PAA `_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(PAA, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/point_rend.py b/mmdet/models/detectors/point_rend.py
new file mode 100644
index 0000000000000000000000000000000000000000..808ef2258ae88301d349db3aaa2711f223e5c971
--- /dev/null
+++ b/mmdet/models/detectors/point_rend.py
@@ -0,0 +1,29 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class PointRend(TwoStageDetector):
+ """PointRend: Image Segmentation as Rendering
+
+ This detector is the implementation of
+    `PointRend <https://arxiv.org/abs/1912.08193>`_.
+
+ """
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+ super(PointRend, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
diff --git a/mmdet/models/detectors/reppoints_detector.py b/mmdet/models/detectors/reppoints_detector.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5f6be31e14488e4b8a006b7142a82c872388d82
--- /dev/null
+++ b/mmdet/models/detectors/reppoints_detector.py
@@ -0,0 +1,22 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class RepPointsDetector(SingleStageDetector):
+ """RepPoints: Point Set Representation for Object Detection.
+
+ This detector is the implementation of:
+ - RepPoints detector (https://arxiv.org/pdf/1904.11490)
+ """
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(RepPointsDetector,
+ self).__init__(backbone, neck, bbox_head, train_cfg, test_cfg,
+ pretrained)
diff --git a/mmdet/models/detectors/retinanet.py b/mmdet/models/detectors/retinanet.py
new file mode 100644
index 0000000000000000000000000000000000000000..41378e8bc74bf9d5cbc7e3e6630bb1e6657049f9
--- /dev/null
+++ b/mmdet/models/detectors/retinanet.py
@@ -0,0 +1,17 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class RetinaNet(SingleStageDetector):
+ """Implementation of `RetinaNet `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(RetinaNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/rpn.py b/mmdet/models/detectors/rpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a77294549d1c3dc7821063c3f3d08bb331fbe59
--- /dev/null
+++ b/mmdet/models/detectors/rpn.py
@@ -0,0 +1,154 @@
+import mmcv
+from mmcv.image import tensor2imgs
+
+from mmdet.core import bbox_mapping
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class RPN(BaseDetector):
+ """Implementation of Region Proposal Network."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ rpn_head,
+ train_cfg,
+ test_cfg,
+ pretrained=None):
+ super(RPN, self).__init__()
+ self.backbone = build_backbone(backbone)
+ self.neck = build_neck(neck) if neck is not None else None
+ rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
+ rpn_head.update(train_cfg=rpn_train_cfg)
+ rpn_head.update(test_cfg=test_cfg.rpn)
+ self.rpn_head = build_head(rpn_head)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.init_weights(pretrained=pretrained)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in detector.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super(RPN, self).init_weights(pretrained)
+ self.backbone.init_weights(pretrained=pretrained)
+ if self.with_neck:
+ self.neck.init_weights()
+ self.rpn_head.init_weights()
+
+ def extract_feat(self, img):
+ """Extract features.
+
+ Args:
+ img (torch.Tensor): Image tensor with shape (n, c, h ,w).
+
+ Returns:
+ list[torch.Tensor]: Multi-level features that may have
+ different resolutions.
+ """
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Dummy forward function."""
+ x = self.extract_feat(img)
+ rpn_outs = self.rpn_head(x)
+ return rpn_outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes=None,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+        gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for one
+            image, in [tl_x, tl_y, br_x, br_y] format.
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ if (isinstance(self.train_cfg.rpn, dict)
+ and self.train_cfg.rpn.get('debug', False)):
+ self.rpn_head.debug_imgs = tensor2imgs(img)
+
+ x = self.extract_feat(img)
+ losses = self.rpn_head.forward_train(x, img_metas, gt_bboxes, None,
+ gt_bboxes_ignore)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation.
+
+ Args:
+            img (torch.Tensor): Batched input images.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[np.ndarray]: proposals
+ """
+ x = self.extract_feat(img)
+ proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
+ if rescale:
+ for proposals, meta in zip(proposal_list, img_metas):
+ proposals[:, :4] /= proposals.new_tensor(meta['scale_factor'])
+
+ return [proposal.cpu().numpy() for proposal in proposal_list]
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (list[torch.Tensor]): List of multiple images
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[np.ndarray]: proposals
+ """
+ proposal_list = self.rpn_head.aug_test_rpn(
+ self.extract_feats(imgs), img_metas)
+ if not rescale:
+ for proposals, img_meta in zip(proposal_list, img_metas[0]):
+ img_shape = img_meta['img_shape']
+ scale_factor = img_meta['scale_factor']
+ flip = img_meta['flip']
+ flip_direction = img_meta['flip_direction']
+ proposals[:, :4] = bbox_mapping(proposals[:, :4], img_shape,
+ scale_factor, flip,
+ flip_direction)
+ return [proposal.cpu().numpy() for proposal in proposal_list]
+
+ def show_result(self, data, result, top_k=20, **kwargs):
+ """Show RPN proposals on the image.
+
+ Args:
+ data (str or np.ndarray): Image filename or loaded image.
+ result (Tensor or tuple): The results to draw over `img`
+ bbox_result or (bbox_result, segm_result).
+ top_k (int): Plot the first k bboxes only
+ if set positive. Default: 20
+
+ Returns:
+ np.ndarray: The image with bboxes drawn on it.
+ """
+        return mmcv.imshow_bboxes(data, result, top_k=top_k)
diff --git a/mmdet/models/detectors/scnet.py b/mmdet/models/detectors/scnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..04a2347c4ec1efcbfda59a134cddd8bde620d983
--- /dev/null
+++ b/mmdet/models/detectors/scnet.py
@@ -0,0 +1,10 @@
+from ..builder import DETECTORS
+from .cascade_rcnn import CascadeRCNN
+
+
+@DETECTORS.register_module()
+class SCNet(CascadeRCNN):
+ """Implementation of `SCNet `_"""
+
+ def __init__(self, **kwargs):
+ super(SCNet, self).__init__(**kwargs)
diff --git a/mmdet/models/detectors/single_stage.py b/mmdet/models/detectors/single_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..5172bdbd945889445eeaa18398c9f0118bb845ad
--- /dev/null
+++ b/mmdet/models/detectors/single_stage.py
@@ -0,0 +1,154 @@
+import torch
+import torch.nn as nn
+
+from mmdet.core import bbox2result
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class SingleStageDetector(BaseDetector):
+ """Base class for single-stage detectors.
+
+ Single-stage detectors directly and densely predict bounding boxes on the
+ output features of the backbone+neck.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ bbox_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(SingleStageDetector, self).__init__()
+ self.backbone = build_backbone(backbone)
+ if neck is not None:
+ self.neck = build_neck(neck)
+ bbox_head.update(train_cfg=train_cfg)
+ bbox_head.update(test_cfg=test_cfg)
+ self.bbox_head = build_head(bbox_head)
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ self.init_weights(pretrained=pretrained)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in detector.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super(SingleStageDetector, self).init_weights(pretrained)
+ self.backbone.init_weights(pretrained=pretrained)
+ if self.with_neck:
+ if isinstance(self.neck, nn.Sequential):
+ for m in self.neck:
+ m.init_weights()
+ else:
+ self.neck.init_weights()
+ self.bbox_head.init_weights()
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone+neck."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+ return outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None):
+ """
+ Args:
+ img (Tensor): Input images of shape (N, C, H, W).
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): A List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for
+                one image, in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): Class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): Specify which bounding
+ boxes can be ignored when computing the loss.
+
+ Returns:
+ dict[str, Tensor]: A dictionary of loss components.
+ """
+ super(SingleStageDetector, self).forward_train(img, img_metas)
+ x = self.extract_feat(img)
+ losses = self.bbox_head.forward_train(x, img_metas, gt_bboxes,
+ gt_labels, gt_bboxes_ignore)
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation.
+
+ Args:
+            img (torch.Tensor): Batched input images.
+ img_metas (list[dict]): List of image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ x = self.extract_feat(img)
+ outs = self.bbox_head(x)
+ # get origin input shape to support onnx dynamic shape
+ if torch.onnx.is_in_onnx_export():
+ # get shape as tensor
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+ bbox_list = self.bbox_head.get_bboxes(
+ *outs, img_metas, rescale=rescale)
+ # skip post-processing when exporting to ONNX
+ if torch.onnx.is_in_onnx_export():
+ return bbox_list
+
+ bbox_results = [
+ bbox2result(det_bboxes, det_labels, self.bbox_head.num_classes)
+ for det_bboxes, det_labels in bbox_list
+ ]
+ return bbox_results
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test function with test time augmentation.
+
+ Args:
+ imgs (list[Tensor]): the outer list indicates test-time
+ augmentations and inner Tensor should have a shape NxCxHxW,
+ which contains all images in the batch.
+ img_metas (list[list[dict]]): the outer list indicates test-time
+ augs (multiscale, flip, etc.) and the inner list indicates
+ images in a batch. each dict has image information.
+ rescale (bool, optional): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ assert hasattr(self.bbox_head, 'aug_test'), \
+ f'{self.bbox_head.__class__.__name__}' \
+ ' does not support test-time augmentation'
+
+ feats = self.extract_feats(imgs)
+ return [self.bbox_head.aug_test(feats, img_metas, rescale=rescale)]
diff --git a/mmdet/models/detectors/sparse_rcnn.py b/mmdet/models/detectors/sparse_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dbd0250f189e610a0bbc72b0dab2559e26857ae
--- /dev/null
+++ b/mmdet/models/detectors/sparse_rcnn.py
@@ -0,0 +1,110 @@
+from ..builder import DETECTORS
+from .two_stage import TwoStageDetector
+
+
+@DETECTORS.register_module()
+class SparseRCNN(TwoStageDetector):
+ r"""Implementation of `Sparse R-CNN: End-to-End Object Detection with
+ Learnable Proposals `_"""
+
+ def __init__(self, *args, **kwargs):
+ super(SparseRCNN, self).__init__(*args, **kwargs)
+        assert self.with_rpn, 'Sparse R-CNN does not support external proposals'
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ proposals=None,
+ **kwargs):
+ """Forward function of SparseR-CNN in train stage.
+
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ :class:`mmdet.datasets.pipelines.Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+            gt_masks (list[Tensor], optional): Segmentation masks for
+                each box; not supported by this architecture.
+ proposals (List[Tensor], optional): override rpn proposals with
+ custom proposals. Use when `with_rpn` is False.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+
+ assert proposals is None, 'Sparse R-CNN does not support' \
+ ' external proposals'
+        assert gt_masks is None, \
+            'Sparse R-CNN does not support instance segmentation'
+
+ x = self.extract_feat(img)
+ proposal_boxes, proposal_features, imgs_whwh = \
+ self.rpn_head.forward_train(x, img_metas)
+ roi_losses = self.roi_head.forward_train(
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ gt_masks=gt_masks,
+ imgs_whwh=imgs_whwh)
+ return roi_losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation.
+
+ Args:
+            img (torch.Tensor): Batched input images.
+ img_metas (list[dict]): List of image information.
+ rescale (bool): Whether to rescale the results.
+ Defaults to False.
+
+ Returns:
+ list[list[np.ndarray]]: BBox results of each image and classes.
+ The outer list corresponds to each image. The inner list
+ corresponds to each class.
+ """
+ x = self.extract_feat(img)
+ proposal_boxes, proposal_features, imgs_whwh = \
+ self.rpn_head.simple_test_rpn(x, img_metas)
+ bbox_results = self.roi_head.simple_test(
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ imgs_whwh=imgs_whwh,
+ rescale=rescale)
+ return bbox_results
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ # backbone
+ x = self.extract_feat(img)
+ # rpn
+ num_imgs = len(img)
+ dummy_img_metas = [
+ dict(img_shape=(800, 1333, 3)) for _ in range(num_imgs)
+ ]
+ proposal_boxes, proposal_features, imgs_whwh = \
+ self.rpn_head.simple_test_rpn(x, dummy_img_metas)
+ # roi_head
+ roi_outs = self.roi_head.forward_dummy(x, proposal_boxes,
+ proposal_features,
+ dummy_img_metas)
+ return roi_outs
diff --git a/mmdet/models/detectors/trident_faster_rcnn.py b/mmdet/models/detectors/trident_faster_rcnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0fd80d41407162df71ba5349fc659d4713cdb6e
--- /dev/null
+++ b/mmdet/models/detectors/trident_faster_rcnn.py
@@ -0,0 +1,66 @@
+from ..builder import DETECTORS
+from .faster_rcnn import FasterRCNN
+
+
+@DETECTORS.register_module()
+class TridentFasterRCNN(FasterRCNN):
+ """Implementation of `TridentNet `_"""
+
+ def __init__(self,
+ backbone,
+ rpn_head,
+ roi_head,
+ train_cfg,
+ test_cfg,
+ neck=None,
+ pretrained=None):
+
+ super(TridentFasterRCNN, self).__init__(
+ backbone=backbone,
+ neck=neck,
+ rpn_head=rpn_head,
+ roi_head=roi_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg,
+ pretrained=pretrained)
+ assert self.backbone.num_branch == self.roi_head.num_branch
+ assert self.backbone.test_branch_idx == self.roi_head.test_branch_idx
+ self.num_branch = self.backbone.num_branch
+ self.test_branch_idx = self.backbone.test_branch_idx
+
+ def simple_test(self, img, img_metas, proposals=None, rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ x = self.extract_feat(img)
+        # `trident_img_metas` is also needed by the roi_head call below, so
+        # build it unconditionally; otherwise passing external proposals
+        # raises a NameError.
+        num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
+        trident_img_metas = img_metas * num_branch
+        if proposals is None:
+            proposal_list = self.rpn_head.simple_test_rpn(x, trident_img_metas)
+        else:
+            proposal_list = proposals
+
+ return self.roi_head.simple_test(
+ x, proposal_list, trident_img_metas, rescale=rescale)
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ x = self.extract_feats(imgs)
+ num_branch = (self.num_branch if self.test_branch_idx == -1 else 1)
+        trident_img_metas = [img_meta * num_branch for img_meta in img_metas]
+ proposal_list = self.rpn_head.aug_test_rpn(x, trident_img_metas)
+ return self.roi_head.aug_test(
+ x, proposal_list, img_metas, rescale=rescale)
+
+ def forward_train(self, img, img_metas, gt_bboxes, gt_labels, **kwargs):
+ """make copies of img and gts to fit multi-branch."""
+ trident_gt_bboxes = tuple(gt_bboxes * self.num_branch)
+ trident_gt_labels = tuple(gt_labels * self.num_branch)
+ trident_img_metas = tuple(img_metas * self.num_branch)
+
+ return super(TridentFasterRCNN,
+ self).forward_train(img, trident_img_metas,
+ trident_gt_bboxes, trident_gt_labels)
diff --git a/mmdet/models/detectors/two_stage.py b/mmdet/models/detectors/two_stage.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba5bdde980dc0cd76375455c9c7ffaae4b25531e
--- /dev/null
+++ b/mmdet/models/detectors/two_stage.py
@@ -0,0 +1,215 @@
+import torch
+import torch.nn as nn
+
+# from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
+from ..builder import DETECTORS, build_backbone, build_head, build_neck
+from .base import BaseDetector
+
+
+@DETECTORS.register_module()
+class TwoStageDetector(BaseDetector):
+ """Base class for two-stage detectors.
+
+    Two-stage detectors typically consist of a region proposal network and a
+    task-specific regression head.
+ """
+
+ def __init__(self,
+ backbone,
+ neck=None,
+ rpn_head=None,
+ roi_head=None,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(TwoStageDetector, self).__init__()
+ self.backbone = build_backbone(backbone)
+
+ if neck is not None:
+ self.neck = build_neck(neck)
+
+ if rpn_head is not None:
+ rpn_train_cfg = train_cfg.rpn if train_cfg is not None else None
+ rpn_head_ = rpn_head.copy()
+ rpn_head_.update(train_cfg=rpn_train_cfg, test_cfg=test_cfg.rpn)
+ self.rpn_head = build_head(rpn_head_)
+
+ if roi_head is not None:
+ # update train and test cfg here for now
+ # TODO: refactor assigner & sampler
+ rcnn_train_cfg = train_cfg.rcnn if train_cfg is not None else None
+ roi_head.update(train_cfg=rcnn_train_cfg)
+ roi_head.update(test_cfg=test_cfg.rcnn)
+ self.roi_head = build_head(roi_head)
+
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+
+ self.init_weights(pretrained=pretrained)
+
+ @property
+ def with_rpn(self):
+ """bool: whether the detector has RPN"""
+ return hasattr(self, 'rpn_head') and self.rpn_head is not None
+
+ @property
+ def with_roi_head(self):
+ """bool: whether the detector has a RoI head"""
+ return hasattr(self, 'roi_head') and self.roi_head is not None
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in detector.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super(TwoStageDetector, self).init_weights(pretrained)
+ self.backbone.init_weights(pretrained=pretrained)
+ if self.with_neck:
+ if isinstance(self.neck, nn.Sequential):
+ for m in self.neck:
+ m.init_weights()
+ else:
+ self.neck.init_weights()
+ if self.with_rpn:
+ self.rpn_head.init_weights()
+ if self.with_roi_head:
+ self.roi_head.init_weights(pretrained)
+
+ def extract_feat(self, img):
+ """Directly extract features from the backbone+neck."""
+ x = self.backbone(img)
+ if self.with_neck:
+ x = self.neck(x)
+ return x
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ outs = ()
+ # backbone
+ x = self.extract_feat(img)
+ # rpn
+ if self.with_rpn:
+ rpn_outs = self.rpn_head(x)
+ outs = outs + (rpn_outs, )
+ proposals = torch.randn(1000, 4).to(img.device)
+ # roi_head
+ roi_outs = self.roi_head.forward_dummy(x, proposals)
+ outs = outs + (roi_outs, )
+ return outs
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ proposals=None,
+ **kwargs):
+ """
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ proposals : override rpn proposals with custom proposals. Use when
+ `with_rpn` is False.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ x = self.extract_feat(img)
+
+ losses = dict()
+
+ # RPN forward and loss
+ if self.with_rpn:
+ proposal_cfg = self.train_cfg.get('rpn_proposal',
+ self.test_cfg.rpn)
+ rpn_losses, proposal_list = self.rpn_head.forward_train(
+ x,
+ img_metas,
+ gt_bboxes,
+ gt_labels=None,
+ gt_bboxes_ignore=gt_bboxes_ignore,
+ proposal_cfg=proposal_cfg)
+ losses.update(rpn_losses)
+ else:
+ proposal_list = proposals
+
+ roi_losses = self.roi_head.forward_train(x, img_metas, proposal_list,
+ gt_bboxes, gt_labels,
+ gt_bboxes_ignore, gt_masks,
+ **kwargs)
+ losses.update(roi_losses)
+
+ return losses
+
+ async def async_simple_test(self,
+ img,
+ img_meta,
+ proposals=None,
+ rescale=False):
+ """Async test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ x = self.extract_feat(img)
+
+ if proposals is None:
+ proposal_list = await self.rpn_head.async_simple_test_rpn(
+ x, img_meta)
+ else:
+ proposal_list = proposals
+
+ return await self.roi_head.async_simple_test(
+ x, proposal_list, img_meta, rescale=rescale)
+
+ def simple_test(self, img, img_metas, proposals=None, rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ x = self.extract_feat(img)
+
+        # get origin input shape to support onnx dynamic input shape
+ if torch.onnx.is_in_onnx_export():
+ img_shape = torch._shape_as_tensor(img)[2:]
+ img_metas[0]['img_shape_for_onnx'] = img_shape
+
+ if proposals is None:
+ proposal_list = self.rpn_head.simple_test_rpn(x, img_metas)
+ else:
+ proposal_list = proposals
+
+ return self.roi_head.simple_test(
+ x, proposal_list, img_metas, rescale=rescale)
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ x = self.extract_feats(imgs)
+ proposal_list = self.rpn_head.aug_test_rpn(x, img_metas)
+ return self.roi_head.aug_test(
+ x, proposal_list, img_metas, rescale=rescale)
diff --git a/mmdet/models/detectors/vfnet.py b/mmdet/models/detectors/vfnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..e23f89674c919921219ffd3486587a2d3c318fbd
--- /dev/null
+++ b/mmdet/models/detectors/vfnet.py
@@ -0,0 +1,18 @@
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class VFNet(SingleStageDetector):
+ """Implementation of `VarifocalNet
+ (VFNet).`_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(VFNet, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/detectors/yolact.py b/mmdet/models/detectors/yolact.py
new file mode 100644
index 0000000000000000000000000000000000000000..f32fde0d3dcbb55a405e05df433c4353938a148b
--- /dev/null
+++ b/mmdet/models/detectors/yolact.py
@@ -0,0 +1,146 @@
+import torch
+
+from mmdet.core import bbox2result
+from ..builder import DETECTORS, build_head
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLACT(SingleStageDetector):
+ """Implementation of `YOLACT `_"""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ segm_head,
+ mask_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(YOLACT, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
+ self.segm_head = build_head(segm_head)
+ self.mask_head = build_head(mask_head)
+ self.init_segm_mask_weights()
+
+ def init_segm_mask_weights(self):
+ """Initialize weights of the YOLACT segm head and YOLACT mask head."""
+ self.segm_head.init_weights()
+ self.mask_head.init_weights()
+
+ def forward_dummy(self, img):
+ """Used for computing network flops.
+
+ See `mmdetection/tools/analysis_tools/get_flops.py`
+ """
+ raise NotImplementedError
+
+ def forward_train(self,
+ img,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """
+ Args:
+ img (Tensor): of shape (N, C, H, W) encoding input images.
+ Typically these should be mean centered and std scaled.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # convert Bitmap mask or Polygon Mask to Tensor here
+ gt_masks = [
+ gt_mask.to_tensor(dtype=torch.uint8, device=img.device)
+ for gt_mask in gt_masks
+ ]
+
+ x = self.extract_feat(img)
+
+ cls_score, bbox_pred, coeff_pred = self.bbox_head(x)
+ bbox_head_loss_inputs = (cls_score, bbox_pred) + (gt_bboxes, gt_labels,
+ img_metas)
+ losses, sampling_results = self.bbox_head.loss(
+ *bbox_head_loss_inputs, gt_bboxes_ignore=gt_bboxes_ignore)
+
+ segm_head_outs = self.segm_head(x[0])
+ loss_segm = self.segm_head.loss(segm_head_outs, gt_masks, gt_labels)
+ losses.update(loss_segm)
+
+ mask_pred = self.mask_head(x[0], coeff_pred, gt_bboxes, img_metas,
+ sampling_results)
+ loss_mask = self.mask_head.loss(mask_pred, gt_masks, gt_bboxes,
+ img_metas, sampling_results)
+ losses.update(loss_mask)
+
+ # check NaN and Inf
+ for loss_name in losses.keys():
+ assert torch.isfinite(torch.stack(losses[loss_name]))\
+ .all().item(), '{} becomes infinite or NaN!'\
+ .format(loss_name)
+
+ return losses
+
+ def simple_test(self, img, img_metas, rescale=False):
+ """Test function without test time augmentation."""
+ x = self.extract_feat(img)
+
+ cls_score, bbox_pred, coeff_pred = self.bbox_head(x)
+
+ bbox_inputs = (cls_score, bbox_pred,
+ coeff_pred) + (img_metas, self.test_cfg, rescale)
+ det_bboxes, det_labels, det_coeffs = self.bbox_head.get_bboxes(
+ *bbox_inputs)
+ bbox_results = [
+ bbox2result(det_bbox, det_label, self.bbox_head.num_classes)
+ for det_bbox, det_label in zip(det_bboxes, det_labels)
+ ]
+
+ num_imgs = len(img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ segm_results = [[[] for _ in range(self.mask_head.num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ mask_preds = self.mask_head(x[0], det_coeffs, _bboxes, img_metas)
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ segm_result = self.mask_head.get_seg_masks(
+ mask_preds[i], det_labels[i], img_metas[i], rescale)
+ segm_results.append(segm_result)
+ return list(zip(bbox_results, segm_results))
+
+ def aug_test(self, imgs, img_metas, rescale=False):
+ """Test with augmentations."""
+ raise NotImplementedError
diff --git a/mmdet/models/detectors/yolo.py b/mmdet/models/detectors/yolo.py
new file mode 100644
index 0000000000000000000000000000000000000000..240aab20f857befe25e64114300ebb15a66c6a70
--- /dev/null
+++ b/mmdet/models/detectors/yolo.py
@@ -0,0 +1,18 @@
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+from ..builder import DETECTORS
+from .single_stage import SingleStageDetector
+
+
+@DETECTORS.register_module()
+class YOLOV3(SingleStageDetector):
+    """Implementation of `YOLOv3 <https://arxiv.org/abs/1804.02767>`_."""
+
+ def __init__(self,
+ backbone,
+ neck,
+ bbox_head,
+ train_cfg=None,
+ test_cfg=None,
+ pretrained=None):
+ super(YOLOV3, self).__init__(backbone, neck, bbox_head, train_cfg,
+ test_cfg, pretrained)
diff --git a/mmdet/models/losses/__init__.py b/mmdet/models/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..297aa228277768eb0ba0e8a377f19704d1feeca8
--- /dev/null
+++ b/mmdet/models/losses/__init__.py
@@ -0,0 +1,29 @@
+from .accuracy import Accuracy, accuracy
+from .ae_loss import AssociativeEmbeddingLoss
+from .balanced_l1_loss import BalancedL1Loss, balanced_l1_loss
+from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
+ cross_entropy, mask_cross_entropy)
+from .focal_loss import FocalLoss, sigmoid_focal_loss
+from .gaussian_focal_loss import GaussianFocalLoss
+from .gfocal_loss import DistributionFocalLoss, QualityFocalLoss
+from .ghm_loss import GHMC, GHMR
+from .iou_loss import (BoundedIoULoss, CIoULoss, DIoULoss, GIoULoss, IoULoss,
+ bounded_iou_loss, iou_loss)
+from .kd_loss import KnowledgeDistillationKLDivLoss
+from .mse_loss import MSELoss, mse_loss
+from .pisa_loss import carl_loss, isr_p
+from .smooth_l1_loss import L1Loss, SmoothL1Loss, l1_loss, smooth_l1_loss
+from .utils import reduce_loss, weight_reduce_loss, weighted_loss
+from .varifocal_loss import VarifocalLoss
+
+__all__ = [
+ 'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
+ 'mask_cross_entropy', 'CrossEntropyLoss', 'sigmoid_focal_loss',
+ 'FocalLoss', 'smooth_l1_loss', 'SmoothL1Loss', 'balanced_l1_loss',
+ 'BalancedL1Loss', 'mse_loss', 'MSELoss', 'iou_loss', 'bounded_iou_loss',
+ 'IoULoss', 'BoundedIoULoss', 'GIoULoss', 'DIoULoss', 'CIoULoss', 'GHMC',
+ 'GHMR', 'reduce_loss', 'weight_reduce_loss', 'weighted_loss', 'L1Loss',
+ 'l1_loss', 'isr_p', 'carl_loss', 'AssociativeEmbeddingLoss',
+ 'GaussianFocalLoss', 'QualityFocalLoss', 'DistributionFocalLoss',
+ 'VarifocalLoss', 'KnowledgeDistillationKLDivLoss'
+]
diff --git a/mmdet/models/losses/accuracy.py b/mmdet/models/losses/accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..789a2240a491289c5801b6690116e8ca657d004f
--- /dev/null
+++ b/mmdet/models/losses/accuracy.py
@@ -0,0 +1,78 @@
+import mmcv
+import torch.nn as nn
+
+
+@mmcv.jit(coderize=True)
+def accuracy(pred, target, topk=1, thresh=None):
+ """Calculate accuracy according to the prediction and target.
+
+ Args:
+ pred (torch.Tensor): The model prediction, shape (N, num_class)
+ target (torch.Tensor): The target of each prediction, shape (N, )
+ topk (int | tuple[int], optional): If the predictions in ``topk``
+ matches the target, the predictions will be regarded as
+ correct ones. Defaults to 1.
+ thresh (float, optional): If not None, predictions with scores under
+            this threshold are considered incorrect. Defaults to None.
+
+ Returns:
+ float | tuple[float]: If the input ``topk`` is a single integer,
+ the function will return a single float as accuracy. If
+ ``topk`` is a tuple containing multiple integers, the
+ function will return a tuple containing accuracies of
+ each ``topk`` number.
+ """
+ assert isinstance(topk, (int, tuple))
+ if isinstance(topk, int):
+ topk = (topk, )
+ return_single = True
+ else:
+ return_single = False
+
+ maxk = max(topk)
+ if pred.size(0) == 0:
+        accu = [pred.new_tensor(0.) for _ in range(len(topk))]
+ return accu[0] if return_single else accu
+ assert pred.ndim == 2 and target.ndim == 1
+ assert pred.size(0) == target.size(0)
+ assert maxk <= pred.size(1), \
+ f'maxk {maxk} exceeds pred dimension {pred.size(1)}'
+ pred_value, pred_label = pred.topk(maxk, dim=1)
+ pred_label = pred_label.t() # transpose to shape (maxk, N)
+ correct = pred_label.eq(target.view(1, -1).expand_as(pred_label))
+ if thresh is not None:
+ # Only prediction values larger than thresh are counted as correct
+ correct = correct & (pred_value > thresh).t()
+ res = []
+ for k in topk:
+ correct_k = correct[:k].reshape(-1).float().sum(0, keepdim=True)
+ res.append(correct_k.mul_(100.0 / pred.size(0)))
+ return res[0] if return_single else res
+
+
+class Accuracy(nn.Module):
+
+ def __init__(self, topk=(1, ), thresh=None):
+ """Module to calculate the accuracy.
+
+ Args:
+ topk (tuple, optional): The criterion used to calculate the
+ accuracy. Defaults to (1,).
+ thresh (float, optional): If not None, predictions with scores
+                under this threshold are considered incorrect. Defaults to None.
+ """
+ super().__init__()
+ self.topk = topk
+ self.thresh = thresh
+
+ def forward(self, pred, target):
+ """Forward function to calculate accuracy.
+
+ Args:
+ pred (torch.Tensor): Prediction of models.
+ target (torch.Tensor): Target for each prediction.
+
+ Returns:
+ tuple[float]: The accuracies under different topk criterions.
+ """
+ return accuracy(pred, target, self.topk, self.thresh)
diff --git a/mmdet/models/losses/ae_loss.py b/mmdet/models/losses/ae_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..cff472aa03080fb49dbb3adba6fec68647a575e6
--- /dev/null
+++ b/mmdet/models/losses/ae_loss.py
@@ -0,0 +1,102 @@
+import mmcv
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def ae_loss_per_image(tl_preds, br_preds, match):
+ """Associative Embedding Loss in one image.
+
+    Associative Embedding Loss consists of two parts: pull loss and push
+    loss. Pull loss pulls embedding vectors from the same object closer to
+    each other. Push loss pushes embedding vectors from different objects
+    apart, so that the gap between them is large enough.
+
+    During computation there are usually 3 cases:
+        - no object in image: both pull loss and push loss are 0.
+        - one object in image: push loss is 0 and pull loss is computed
+            from the two corners of the only object.
+        - more than one object in image: pull loss is computed from the
+            corner pair of each object; push loss is computed between each
+            object and all other objects, using a confusion matrix with 0
+            on the diagonal.
+
+ Args:
+        tl_preds (tensor): Embedding feature map of the top-left corner.
+        br_preds (tensor): Embedding feature map of the bottom-right corner.
+        match (list): Downsampled coordinate pair of each ground truth box.
+ """
+
+ tl_list, br_list, me_list = [], [], []
+ if len(match) == 0: # no object in image
+ pull_loss = tl_preds.sum() * 0.
+ push_loss = tl_preds.sum() * 0.
+ else:
+ for m in match:
+ [tl_y, tl_x], [br_y, br_x] = m
+ tl_e = tl_preds[:, tl_y, tl_x].view(-1, 1)
+ br_e = br_preds[:, br_y, br_x].view(-1, 1)
+ tl_list.append(tl_e)
+ br_list.append(br_e)
+ me_list.append((tl_e + br_e) / 2.0)
+
+ tl_list = torch.cat(tl_list)
+ br_list = torch.cat(br_list)
+ me_list = torch.cat(me_list)
+
+ assert tl_list.size() == br_list.size()
+
+ # N is object number in image, M is dimension of embedding vector
+ N, M = tl_list.size()
+
+ pull_loss = (tl_list - me_list).pow(2) + (br_list - me_list).pow(2)
+ pull_loss = pull_loss.sum() / N
+
+ margin = 1 # exp setting of CornerNet, details in section 3.3 of paper
+
+ # confusion matrix of push loss
+ conf_mat = me_list.expand((N, N, M)).permute(1, 0, 2) - me_list
+ conf_weight = 1 - torch.eye(N).type_as(me_list)
+ conf_mat = conf_weight * (margin - conf_mat.sum(-1).abs())
+
+ if N > 1: # more than one object in current image
+ push_loss = F.relu(conf_mat).sum() / (N * (N - 1))
+ else:
+ push_loss = tl_preds.sum() * 0.
+
+ return pull_loss, push_loss
+
+
+@LOSSES.register_module()
+class AssociativeEmbeddingLoss(nn.Module):
+ """Associative Embedding Loss.
+
+ More details can be found in
+    `Associative Embedding <https://arxiv.org/abs/1611.05424>`_ and
+    `CornerNet <https://arxiv.org/abs/1808.01244>`_.
+    Code is modified from `kp_utils.py <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py>`_  # noqa: E501
+
+ Args:
+ pull_weight (float): Loss weight for corners from same object.
+ push_weight (float): Loss weight for corners from different object.
+ """
+
+ def __init__(self, pull_weight=0.25, push_weight=0.25):
+ super(AssociativeEmbeddingLoss, self).__init__()
+ self.pull_weight = pull_weight
+ self.push_weight = push_weight
+
+ def forward(self, pred, target, match):
+ """Forward function."""
+ batch = pred.size(0)
+ pull_all, push_all = 0.0, 0.0
+ for i in range(batch):
+ pull, push = ae_loss_per_image(pred[i], target[i], match[i])
+
+ pull_all += self.pull_weight * pull
+ push_all += self.push_weight * push
+
+ return pull_all, push_all
diff --git a/mmdet/models/losses/balanced_l1_loss.py b/mmdet/models/losses/balanced_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bcd13ff26dbdc9f6eff8d7c7b5bde742a8d7d1d
--- /dev/null
+++ b/mmdet/models/losses/balanced_l1_loss.py
@@ -0,0 +1,120 @@
+import mmcv
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def balanced_l1_loss(pred,
+ target,
+ beta=1.0,
+ alpha=0.5,
+ gamma=1.5,
+ reduction='mean'):
+ """Calculate balanced L1 loss.
+
+    Please see the paper `Libra R-CNN <https://arxiv.org/abs/1904.02701>`_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 4).
+ target (torch.Tensor): The learning target of the prediction with
+ shape (N, 4).
+ beta (float): The loss is a piecewise function of prediction and target
+ and ``beta`` serves as a threshold for the difference between the
+ prediction and target. Defaults to 1.0.
+ alpha (float): The denominator ``alpha`` in the balanced L1 loss.
+ Defaults to 0.5.
+ gamma (float): The ``gamma`` in the balanced L1 loss.
+ Defaults to 1.5.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert beta > 0
+ assert pred.size() == target.size() and target.numel() > 0
+
+ diff = torch.abs(pred - target)
+ b = np.e**(gamma / alpha) - 1
+ loss = torch.where(
+ diff < beta, alpha / b *
+ (b * diff + 1) * torch.log(b * diff / beta + 1) - alpha * diff,
+ gamma * diff + gamma / b - alpha * beta)
+
+ return loss
+
+
+@LOSSES.register_module()
+class BalancedL1Loss(nn.Module):
+ """Balanced L1 Loss.
+
+ arXiv: https://arxiv.org/pdf/1904.02701.pdf (CVPR 2019)
+
+ Args:
+ alpha (float): The denominator ``alpha`` in the balanced L1 loss.
+ Defaults to 0.5.
+ gamma (float): The ``gamma`` in the balanced L1 loss. Defaults to 1.5.
+ beta (float, optional): The loss is a piecewise function of prediction
+ and target. ``beta`` serves as a threshold for the difference
+ between the prediction and target. Defaults to 1.0.
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self,
+ alpha=0.5,
+ gamma=1.5,
+ beta=1.0,
+ reduction='mean',
+ loss_weight=1.0):
+ super(BalancedL1Loss, self).__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 4).
+ target (torch.Tensor): The learning target of the prediction with
+ shape (N, 4).
+ weight (torch.Tensor, optional): Sample-wise loss weight with
+ shape (N, ).
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_bbox = self.loss_weight * balanced_l1_loss(
+ pred,
+ target,
+ weight,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_bbox
diff --git a/mmdet/models/losses/cross_entropy_loss.py b/mmdet/models/losses/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa908d2789e291616acf969912bf4429b1b07bf
--- /dev/null
+++ b/mmdet/models/losses/cross_entropy_loss.py
@@ -0,0 +1,216 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+def cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None):
+ """Calculate the CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ # element-wise losses
+ loss = F.cross_entropy(pred, label, weight=class_weight, reduction='none')
+
+ # apply weights and do the reduction
+ if weight is not None:
+ weight = weight.float()
+ loss = weight_reduce_loss(
+ loss, weight=weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def _expand_onehot_labels(labels, label_weights, label_channels):
+ bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+ inds = torch.nonzero(
+ (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
+ if inds.numel() > 0:
+ bin_labels[inds, labels[inds]] = 1
+
+ if label_weights is None:
+ bin_label_weights = None
+ else:
+ bin_label_weights = label_weights.view(-1, 1).expand(
+ label_weights.size(0), label_channels)
+
+ return bin_labels, bin_label_weights
+
+
+def binary_cross_entropy(pred,
+ label,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None):
+ """Calculate the binary CrossEntropy loss.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, 1).
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ if pred.dim() != label.dim():
+ label, weight = _expand_onehot_labels(label, weight, pred.size(-1))
+
+ # weighted element-wise losses
+ if weight is not None:
+ weight = weight.float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, label.float(), pos_weight=class_weight, reduction='none')
+ # do the reduction for the weighted loss
+ loss = weight_reduce_loss(
+ loss, weight, reduction=reduction, avg_factor=avg_factor)
+
+ return loss
+
+
+def mask_cross_entropy(pred,
+ target,
+ label,
+ reduction='mean',
+ avg_factor=None,
+ class_weight=None):
+ """Calculate the CrossEntropy loss for masks.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C, *), C is the
+ number of classes. The trailing * indicates arbitrary shape.
+ target (torch.Tensor): The learning label of the prediction.
+        label (torch.Tensor): ``label`` indicates the class label of the
+            object that each mask corresponds to. It is used to select the
+            mask of the object's class when the mask prediction is not
+            class-agnostic.
+ reduction (str, optional): The method used to reduce the loss.
+ Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ class_weight (list[float], optional): The weight for each class.
+
+ Returns:
+ torch.Tensor: The calculated loss
+
+ Example:
+ >>> N, C = 3, 11
+ >>> H, W = 2, 2
+ >>> pred = torch.randn(N, C, H, W) * 1000
+ >>> target = torch.rand(N, H, W)
+ >>> label = torch.randint(0, C, size=(N,))
+ >>> reduction = 'mean'
+ >>> avg_factor = None
+ >>> class_weights = None
+ >>> loss = mask_cross_entropy(pred, target, label, reduction,
+ >>> avg_factor, class_weights)
+ >>> assert loss.shape == (1,)
+ """
+ # TODO: handle these two reserved arguments
+ assert reduction == 'mean' and avg_factor is None
+ num_rois = pred.size()[0]
+ inds = torch.arange(0, num_rois, dtype=torch.long, device=pred.device)
+ pred_slice = pred[inds, label].squeeze(1)
+ return F.binary_cross_entropy_with_logits(
+ pred_slice, target, weight=class_weight, reduction='mean')[None]
+
+
+@LOSSES.register_module()
+class CrossEntropyLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=False,
+ use_mask=False,
+ reduction='mean',
+ class_weight=None,
+ loss_weight=1.0):
+ """CrossEntropyLoss.
+
+ Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Defaults to False.
+ use_mask (bool, optional): Whether to use mask cross entropy loss.
+ Defaults to False.
+            reduction (str, optional): The method used to reduce the loss.
+                Defaults to 'mean'. Options are "none", "mean" and "sum".
+ class_weight (list[float], optional): Weight of each class.
+ Defaults to None.
+ loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
+ """
+ super(CrossEntropyLoss, self).__init__()
+ assert (use_sigmoid is False) or (use_mask is False)
+ self.use_sigmoid = use_sigmoid
+ self.use_mask = use_mask
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.class_weight = class_weight
+
+ if self.use_sigmoid:
+ self.cls_criterion = binary_cross_entropy
+ elif self.use_mask:
+ self.cls_criterion = mask_cross_entropy
+ else:
+ self.cls_criterion = cross_entropy
+
+ def forward(self,
+ cls_score,
+ label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ cls_score (torch.Tensor): The prediction.
+ label (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+            reduction_override (str, optional): The reduction method used to
+                override the original reduction method of the loss.
+                Options are "none", "mean" and "sum".
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.class_weight is not None:
+ class_weight = cls_score.new_tensor(
+ self.class_weight, device=cls_score.device)
+ else:
+ class_weight = None
+ loss_cls = self.loss_weight * self.cls_criterion(
+ cls_score,
+ label,
+ weight,
+ class_weight=class_weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_cls
diff --git a/mmdet/models/losses/focal_loss.py b/mmdet/models/losses/focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..493907c6984d532175e0351daf2eafe4b9ff0256
--- /dev/null
+++ b/mmdet/models/losses/focal_loss.py
@@ -0,0 +1,181 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.ops import sigmoid_focal_loss as _sigmoid_focal_loss
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+# This method is only for debugging
+def py_sigmoid_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+ """PyTorch version of `Focal Loss `_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the
+ number of classes
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
+ focal_weight = (alpha * target + (1 - alpha) *
+ (1 - target)) * pt.pow(gamma)
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, reduction='none') * focal_weight
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+def sigmoid_focal_loss(pred,
+ target,
+ weight=None,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ avg_factor=None):
+ r"""A warpper of cuda version `Focal Loss
+ `_.
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the number
+ of classes.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): Sample-wise loss weight.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ # Function.apply does not accept keyword arguments, so the decorator
+ # "weighted_loss" is not applicable
+ loss = _sigmoid_focal_loss(pred.contiguous(), target, gamma, alpha, None,
+ 'none')
+ if weight is not None:
+ if weight.shape != loss.shape:
+ if weight.size(0) == loss.size(0):
+ # For most cases, weight is of shape (num_priors, ),
+ # which means it does not have the second axis num_class
+ weight = weight.view(-1, 1)
+ else:
+ # Sometimes, weight per anchor per class is also needed. e.g.
+ # in FSAF. But it may be flattened of shape
+ # (num_priors x num_class, ), while loss is still of shape
+ # (num_priors, num_class).
+ assert weight.numel() == loss.numel()
+ weight = weight.view(loss.size(0), -1)
+ assert weight.ndim == loss.ndim
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@LOSSES.register_module()
+class FocalLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=True,
+ gamma=2.0,
+ alpha=0.25,
+ reduction='mean',
+ loss_weight=1.0):
+ """`Focal Loss `_
+
+ Args:
+            use_sigmoid (bool, optional): Whether the prediction uses sigmoid
+                instead of softmax. Defaults to True.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 0.25.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ """
+ super(FocalLoss, self).__init__()
+ assert use_sigmoid is True, 'Only sigmoid focal loss supported now.'
+ self.use_sigmoid = use_sigmoid
+ self.gamma = gamma
+ self.alpha = alpha
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning label of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ if torch.cuda.is_available() and pred.is_cuda:
+ calculate_loss_func = sigmoid_focal_loss
+ else:
+ num_classes = pred.size(1)
+ target = F.one_hot(target, num_classes=num_classes + 1)
+ target = target[:, :num_classes]
+ calculate_loss_func = py_sigmoid_focal_loss
+
+ loss_cls = self.loss_weight * calculate_loss_func(
+ pred,
+ target,
+ weight,
+ gamma=self.gamma,
+ alpha=self.alpha,
+ reduction=reduction,
+ avg_factor=avg_factor)
+
+ else:
+ raise NotImplementedError
+ return loss_cls
diff --git a/mmdet/models/losses/gaussian_focal_loss.py b/mmdet/models/losses/gaussian_focal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e45506a38e8e3c187be8288d0b714cc1ee29cf27
--- /dev/null
+++ b/mmdet/models/losses/gaussian_focal_loss.py
@@ -0,0 +1,91 @@
+import mmcv
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def gaussian_focal_loss(pred, gaussian_target, alpha=2.0, gamma=4.0):
+ """`Focal Loss `_ for targets in gaussian
+ distribution.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ gaussian_target (torch.Tensor): The learning target of the prediction
+ in gaussian distribution.
+ alpha (float, optional): A balanced form for Focal Loss.
+ Defaults to 2.0.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 4.0.
+ """
+ eps = 1e-12
+ pos_weights = gaussian_target.eq(1)
+ neg_weights = (1 - gaussian_target).pow(gamma)
+ pos_loss = -(pred + eps).log() * (1 - pred).pow(alpha) * pos_weights
+ neg_loss = -(1 - pred + eps).log() * pred.pow(alpha) * neg_weights
+ return pos_loss + neg_loss
+
+
+@LOSSES.register_module()
+class GaussianFocalLoss(nn.Module):
+ """GaussianFocalLoss is a variant of focal loss.
+
+    More details can be found in the `paper
+    <https://arxiv.org/abs/1808.01244>`_.
+    Code is modified from `kp_utils.py
+    <https://github.com/princeton-vl/CornerNet/blob/master/models/py_utils/kp_utils.py>`_  # noqa: E501
+ Please notice that the target in GaussianFocalLoss is a gaussian heatmap,
+ not 0/1 binary target.
+
+ Args:
+ alpha (float): Power of prediction.
+ gamma (float): Power of target for negative samples.
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Loss weight of current loss.
+ """
+
+ def __init__(self,
+ alpha=2.0,
+ gamma=4.0,
+ reduction='mean',
+ loss_weight=1.0):
+ super(GaussianFocalLoss, self).__init__()
+ self.alpha = alpha
+ self.gamma = gamma
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction
+ in gaussian distribution.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_reg = self.loss_weight * gaussian_focal_loss(
+ pred,
+ target,
+ weight,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ return loss_reg
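
A hedged usage sketch for the module above: it expects sigmoid-activated heatmaps in [0, 1] and a gaussian heatmap target whose peaks equal 1 at object centers. The shapes and the single hand-placed center are illustrative only.

```python
import torch

loss_fn = GaussianFocalLoss(alpha=2.0, gamma=4.0)
pred = torch.rand(2, 80, 128, 128)  # heatmap after sigmoid, values in (0, 1)
target = torch.zeros_like(pred)
target[0, 3, 64, 64] = 1.0          # one positive center; real targets also
                                    # carry a gaussian falloff around it
loss = loss_fn(pred, target)        # scalar, reduction='mean' by default
```
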
diff --git a/mmdet/models/losses/gfocal_loss.py b/mmdet/models/losses/gfocal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d3b8833dc50c76f6741db5341dbf8da3402d07b
--- /dev/null
+++ b/mmdet/models/losses/gfocal_loss.py
@@ -0,0 +1,188 @@
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def quality_focal_loss(pred, target, beta=2.0):
+ r"""Quality Focal Loss (QFL) is from `Generalized Focal Loss: Learning
+ Qualified and Distributed Bounding Boxes for Dense Object Detection
+    <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted joint representation of classification
+ and quality (IoU) estimation with shape (N, C), C is the number of
+ classes.
+ target (tuple([torch.Tensor])): Target category label with shape (N,)
+ and target quality label with shape (N,).
+ beta (float): The beta parameter for calculating the modulating factor.
+ Defaults to 2.0.
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
+ """
+ assert len(target) == 2, """target for QFL must be a tuple of two elements,
+ including category label and quality label, respectively"""
+ # label denotes the category id, score denotes the quality score
+ label, score = target
+
+ # negatives are supervised by 0 quality score
+ pred_sigmoid = pred.sigmoid()
+ scale_factor = pred_sigmoid
+ zerolabel = scale_factor.new_zeros(pred.shape)
+ loss = F.binary_cross_entropy_with_logits(
+ pred, zerolabel, reduction='none') * scale_factor.pow(beta)
+
+ # FG cat_id: [0, num_classes -1], BG cat_id: num_classes
+ bg_class_ind = pred.size(1)
+ pos = ((label >= 0) & (label < bg_class_ind)).nonzero().squeeze(1)
+ pos_label = label[pos].long()
+ # positives are supervised by bbox quality (IoU) score
+ scale_factor = score[pos] - pred_sigmoid[pos, pos_label]
+ loss[pos, pos_label] = F.binary_cross_entropy_with_logits(
+ pred[pos, pos_label], score[pos],
+ reduction='none') * scale_factor.abs().pow(beta)
+
+ loss = loss.sum(dim=1, keepdim=False)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def distribution_focal_loss(pred, label):
+ r"""Distribution Focal Loss (DFL) is from `Generalized Focal Loss: Learning
+ Qualified and Distributed Bounding Boxes for Dense Object Detection
+    <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted general distribution of bounding boxes
+ (before softmax) with shape (N, n+1), n is the max value of the
+ integral set `{0, ..., n}` in paper.
+ label (torch.Tensor): Target distance label for bounding boxes with
+ shape (N,).
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
+ """
+ dis_left = label.long()
+ dis_right = dis_left + 1
+ weight_left = dis_right.float() - label
+ weight_right = label - dis_left.float()
+ loss = F.cross_entropy(pred, dis_left, reduction='none') * weight_left \
+ + F.cross_entropy(pred, dis_right, reduction='none') * weight_right
+ return loss
+
+
+@LOSSES.register_module()
+class QualityFocalLoss(nn.Module):
+ r"""Quality Focal Loss (QFL) is a variant of `Generalized Focal Loss:
+ Learning Qualified and Distributed Bounding Boxes for Dense Object
+    Detection <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ use_sigmoid (bool): Whether sigmoid operation is conducted in QFL.
+ Defaults to True.
+ beta (float): The beta parameter for calculating the modulating factor.
+ Defaults to 2.0.
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Loss weight of current loss.
+ """
+
+ def __init__(self,
+ use_sigmoid=True,
+ beta=2.0,
+ reduction='mean',
+ loss_weight=1.0):
+ super(QualityFocalLoss, self).__init__()
+ assert use_sigmoid is True, 'Only sigmoid in QFL supported now.'
+ self.use_sigmoid = use_sigmoid
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): Predicted joint representation of
+ classification and quality (IoU) estimation with shape (N, C),
+ C is the number of classes.
+ target (tuple([torch.Tensor])): Target category label with shape
+ (N,) and target quality label with shape (N,).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ loss_cls = self.loss_weight * quality_focal_loss(
+ pred,
+ target,
+ weight,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ else:
+ raise NotImplementedError
+ return loss_cls
+
+
+@LOSSES.register_module()
+class DistributionFocalLoss(nn.Module):
+ r"""Distribution Focal Loss (DFL) is a variant of `Generalized Focal Loss:
+ Learning Qualified and Distributed Bounding Boxes for Dense Object
+    Detection <https://arxiv.org/abs/2006.04388>`_.
+
+ Args:
+ reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
+ loss_weight (float): Loss weight of current loss.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super(DistributionFocalLoss, self).__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): Predicted general distribution of bounding
+ boxes (before softmax) with shape (N, n+1), n is the max value
+ of the integral set `{0, ..., n}` in paper.
+ target (torch.Tensor): Target distance label for bounding boxes
+ with shape (N,).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_cls = self.loss_weight * distribution_focal_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss_cls
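
A smoke-test sketch for both modules; shapes follow the docstrings, while the class count (80, with id 80 as background) and the 17 distribution bins are arbitrary assumptions.

```python
import torch

qfl = QualityFocalLoss(use_sigmoid=True, beta=2.0)
cls_logits = torch.randn(16, 80)        # joint cls-IoU logits, shape (N, C)
labels = torch.randint(0, 81, (16,))    # id 80 means background
quality = torch.rand(16)                # IoU quality targets for positives
loss_q = qfl(cls_logits, (labels, quality))

dfl = DistributionFocalLoss(loss_weight=0.25)
dist_logits = torch.randn(16, 17)       # distribution over {0, ..., 16}
dist_target = torch.rand(16) * 16       # continuous distances in [0, 16)
loss_d = dfl(dist_logits, dist_target)
```
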
diff --git a/mmdet/models/losses/ghm_loss.py b/mmdet/models/losses/ghm_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..8969a23fd98bb746415f96ac5e4ad9e37ba3af52
--- /dev/null
+++ b/mmdet/models/losses/ghm_loss.py
@@ -0,0 +1,172 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+
+
+def _expand_onehot_labels(labels, label_weights, label_channels):
+ bin_labels = labels.new_full((labels.size(0), label_channels), 0)
+ inds = torch.nonzero(
+ (labels >= 0) & (labels < label_channels), as_tuple=False).squeeze()
+ if inds.numel() > 0:
+ bin_labels[inds, labels[inds]] = 1
+ bin_label_weights = label_weights.view(-1, 1).expand(
+ label_weights.size(0), label_channels)
+ return bin_labels, bin_label_weights
+
+
+# TODO: code refactoring to make it consistent with other losses
+@LOSSES.register_module()
+class GHMC(nn.Module):
+ """GHM Classification Loss.
+
+ Details of the theorem can be viewed in the paper
+ `Gradient Harmonized Single-stage Detector
+    <https://arxiv.org/abs/1811.05181>`_.
+
+ Args:
+ bins (int): Number of the unit regions for distribution calculation.
+ momentum (float): The parameter for moving average.
+ use_sigmoid (bool): Can only be true for BCE based loss now.
+ loss_weight (float): The weight of the total GHM-C loss.
+ """
+
+ def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
+ super(GHMC, self).__init__()
+ self.bins = bins
+ self.momentum = momentum
+ edges = torch.arange(bins + 1).float() / bins
+ self.register_buffer('edges', edges)
+ self.edges[-1] += 1e-6
+ if momentum > 0:
+ acc_sum = torch.zeros(bins)
+ self.register_buffer('acc_sum', acc_sum)
+ self.use_sigmoid = use_sigmoid
+ if not self.use_sigmoid:
+ raise NotImplementedError
+ self.loss_weight = loss_weight
+
+ def forward(self, pred, target, label_weight, *args, **kwargs):
+ """Calculate the GHM-C loss.
+
+ Args:
+ pred (float tensor of size [batch_num, class_num]):
+ The direct prediction of classification fc layer.
+ target (float tensor of size [batch_num, class_num]):
+ Binary class target for each sample.
+ label_weight (float tensor of size [batch_num, class_num]):
+ the value is 1 if the sample is valid and 0 if ignored.
+ Returns:
+ The gradient harmonized loss.
+ """
+ # the target should be binary class label
+ if pred.dim() != target.dim():
+ target, label_weight = _expand_onehot_labels(
+ target, label_weight, pred.size(-1))
+ target, label_weight = target.float(), label_weight.float()
+ edges = self.edges
+ mmt = self.momentum
+ weights = torch.zeros_like(pred)
+
+ # gradient length
+ g = torch.abs(pred.sigmoid().detach() - target)
+
+ valid = label_weight > 0
+ tot = max(valid.float().sum().item(), 1.0)
+ n = 0 # n valid bins
+ for i in range(self.bins):
+ inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
+ num_in_bin = inds.sum().item()
+ if num_in_bin > 0:
+ if mmt > 0:
+ self.acc_sum[i] = mmt * self.acc_sum[i] \
+ + (1 - mmt) * num_in_bin
+ weights[inds] = tot / self.acc_sum[i]
+ else:
+ weights[inds] = tot / num_in_bin
+ n += 1
+ if n > 0:
+ weights = weights / n
+
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, weights, reduction='sum') / tot
+ return loss * self.loss_weight
+
+
+# TODO: code refactoring to make it consistent with other losses
+@LOSSES.register_module()
+class GHMR(nn.Module):
+ """GHM Regression Loss.
+
+ Details of the theorem can be viewed in the paper
+ `Gradient Harmonized Single-stage Detector
+    <https://arxiv.org/abs/1811.05181>`_.
+
+ Args:
+ mu (float): The parameter for the Authentic Smooth L1 loss.
+ bins (int): Number of the unit regions for distribution calculation.
+ momentum (float): The parameter for moving average.
+ loss_weight (float): The weight of the total GHM-R loss.
+ """
+
+ def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0):
+ super(GHMR, self).__init__()
+ self.mu = mu
+ self.bins = bins
+ edges = torch.arange(bins + 1).float() / bins
+ self.register_buffer('edges', edges)
+ self.edges[-1] = 1e3
+ self.momentum = momentum
+ if momentum > 0:
+ acc_sum = torch.zeros(bins)
+ self.register_buffer('acc_sum', acc_sum)
+ self.loss_weight = loss_weight
+
+ # TODO: support reduction parameter
+ def forward(self, pred, target, label_weight, avg_factor=None):
+ """Calculate the GHM-R loss.
+
+ Args:
+ pred (float tensor of size [batch_num, 4 (* class_num)]):
+ The prediction of box regression layer. Channel number can be 4
+ or 4 * class_num depending on whether it is class-agnostic.
+ target (float tensor of size [batch_num, 4 (* class_num)]):
+ The target regression values with the same size of pred.
+ label_weight (float tensor of size [batch_num, 4 (* class_num)]):
+ The weight of each sample, 0 if ignored.
+ Returns:
+ The gradient harmonized loss.
+ """
+ mu = self.mu
+ edges = self.edges
+ mmt = self.momentum
+
+ # ASL1 loss
+ diff = pred - target
+ loss = torch.sqrt(diff * diff + mu * mu) - mu
+
+ # gradient length
+ g = torch.abs(diff / torch.sqrt(mu * mu + diff * diff)).detach()
+ weights = torch.zeros_like(g)
+
+ valid = label_weight > 0
+ tot = max(label_weight.float().sum().item(), 1.0)
+ n = 0 # n: valid bins
+ for i in range(self.bins):
+ inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
+ num_in_bin = inds.sum().item()
+ if num_in_bin > 0:
+ n += 1
+ if mmt > 0:
+ self.acc_sum[i] = mmt * self.acc_sum[i] \
+ + (1 - mmt) * num_in_bin
+ weights[inds] = tot / self.acc_sum[i]
+ else:
+ weights[inds] = tot / num_in_bin
+ if n > 0:
+ weights /= n
+
+ loss = loss * weights
+ loss = loss.sum() / tot
+ return loss * self.loss_weight
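
A sketch of how the two heads are typically driven, with dummy tensors: GHM-C takes binary (or integer, auto-expanded) class targets plus a validity mask, GHM-R takes per-coordinate weights.

```python
import torch

ghmc = GHMC(bins=10, momentum=0.75)
cls_pred = torch.randn(32, 80)                      # raw logits
cls_target = torch.randint(0, 2, (32, 80)).float()  # binary class targets
label_weight = torch.ones(32, 80)                   # 1 = valid, 0 = ignored
loss_cls = ghmc(cls_pred, cls_target, label_weight)

ghmr = GHMR(mu=0.02, bins=10)
reg_pred = torch.randn(32, 4)
reg_target = torch.randn(32, 4)
reg_weight = torch.ones(32, 4)
loss_reg = ghmr(reg_pred, reg_target, reg_weight)
```
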
diff --git a/mmdet/models/losses/iou_loss.py b/mmdet/models/losses/iou_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..eba6f18b80981ca891c1add37007e6bf478c651f
--- /dev/null
+++ b/mmdet/models/losses/iou_loss.py
@@ -0,0 +1,436 @@
+import math
+
+import mmcv
+import torch
+import torch.nn as nn
+
+from mmdet.core import bbox_overlaps
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def iou_loss(pred, target, linear=False, eps=1e-6):
+ """IoU loss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+ The loss is calculated as negative log of IoU.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
+ linear (bool, optional): If True, use linear scale of loss instead of
+ log scale. Default: False.
+ eps (float): Eps to avoid log(0).
+
+ Return:
+ torch.Tensor: Loss tensor.
+ """
+ ious = bbox_overlaps(pred, target, is_aligned=True).clamp(min=eps)
+ if linear:
+ loss = 1 - ious
+ else:
+ loss = -ious.log()
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def bounded_iou_loss(pred, target, beta=0.2, eps=1e-3):
+ """BIoULoss.
+
+ This is an implementation of paper
+ `Improving Object Localization with Fitness NMS and Bounded IoU Loss.
+    <https://arxiv.org/abs/1711.00164>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes.
+ target (torch.Tensor): Target bboxes.
+ beta (float): beta parameter in smoothl1.
+ eps (float): eps to avoid NaN.
+ """
+ pred_ctrx = (pred[:, 0] + pred[:, 2]) * 0.5
+ pred_ctry = (pred[:, 1] + pred[:, 3]) * 0.5
+ pred_w = pred[:, 2] - pred[:, 0]
+ pred_h = pred[:, 3] - pred[:, 1]
+ with torch.no_grad():
+ target_ctrx = (target[:, 0] + target[:, 2]) * 0.5
+ target_ctry = (target[:, 1] + target[:, 3]) * 0.5
+ target_w = target[:, 2] - target[:, 0]
+ target_h = target[:, 3] - target[:, 1]
+
+ dx = target_ctrx - pred_ctrx
+ dy = target_ctry - pred_ctry
+
+ loss_dx = 1 - torch.max(
+ (target_w - 2 * dx.abs()) /
+ (target_w + 2 * dx.abs() + eps), torch.zeros_like(dx))
+ loss_dy = 1 - torch.max(
+ (target_h - 2 * dy.abs()) /
+ (target_h + 2 * dy.abs() + eps), torch.zeros_like(dy))
+ loss_dw = 1 - torch.min(target_w / (pred_w + eps), pred_w /
+ (target_w + eps))
+ loss_dh = 1 - torch.min(target_h / (pred_h + eps), pred_h /
+ (target_h + eps))
+ loss_comb = torch.stack([loss_dx, loss_dy, loss_dw, loss_dh],
+ dim=-1).view(loss_dx.size(0), -1)
+
+ loss = torch.where(loss_comb < beta, 0.5 * loss_comb * loss_comb / beta,
+ loss_comb - 0.5 * beta)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def giou_loss(pred, target, eps=1e-7):
+ r"""`Generalized Intersection over Union: A Metric and A Loss for Bounding
+    Box Regression <https://arxiv.org/abs/1902.09630>`_.
+
+ Args:
+ pred (torch.Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (torch.Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+
+ Return:
+ Tensor: Loss tensor.
+ """
+ gious = bbox_overlaps(pred, target, mode='giou', is_aligned=True, eps=eps)
+ loss = 1 - gious
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def diou_loss(pred, target, eps=1e-7):
+ r"""`Implementation of Distance-IoU Loss: Faster and Better
+ Learning for Bounding Box Regression, https://arxiv.org/abs/1911.08287`_.
+
+ Code is modified from https://github.com/Zzh-tju/DIoU.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+ Return:
+ Tensor: Loss tensor.
+ """
+ # overlap
+ lt = torch.max(pred[:, :2], target[:, :2])
+ rb = torch.min(pred[:, 2:], target[:, 2:])
+ wh = (rb - lt).clamp(min=0)
+ overlap = wh[:, 0] * wh[:, 1]
+
+ # union
+ ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+ ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+ union = ap + ag - overlap + eps
+
+ # IoU
+ ious = overlap / union
+
+ # enclose area
+ enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+ enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+ enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
+
+ cw = enclose_wh[:, 0]
+ ch = enclose_wh[:, 1]
+
+ c2 = cw**2 + ch**2 + eps
+
+ b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
+ b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
+ b2_x1, b2_y1 = target[:, 0], target[:, 1]
+ b2_x2, b2_y2 = target[:, 2], target[:, 3]
+
+ left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
+ right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
+ rho2 = left + right
+
+ # DIoU
+ dious = ious - rho2 / c2
+ loss = 1 - dious
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def ciou_loss(pred, target, eps=1e-7):
+ r"""`Implementation of paper `Enhancing Geometric Factors into
+ Model Learning and Inference for Object Detection and Instance
+    Segmentation <https://arxiv.org/abs/2005.03572>`_.
+
+ Code is modified from https://github.com/Zzh-tju/CIoU.
+
+ Args:
+ pred (Tensor): Predicted bboxes of format (x1, y1, x2, y2),
+ shape (n, 4).
+ target (Tensor): Corresponding gt bboxes, shape (n, 4).
+ eps (float): Eps to avoid log(0).
+ Return:
+ Tensor: Loss tensor.
+ """
+ # overlap
+ lt = torch.max(pred[:, :2], target[:, :2])
+ rb = torch.min(pred[:, 2:], target[:, 2:])
+ wh = (rb - lt).clamp(min=0)
+ overlap = wh[:, 0] * wh[:, 1]
+
+ # union
+ ap = (pred[:, 2] - pred[:, 0]) * (pred[:, 3] - pred[:, 1])
+ ag = (target[:, 2] - target[:, 0]) * (target[:, 3] - target[:, 1])
+ union = ap + ag - overlap + eps
+
+ # IoU
+ ious = overlap / union
+
+ # enclose area
+ enclose_x1y1 = torch.min(pred[:, :2], target[:, :2])
+ enclose_x2y2 = torch.max(pred[:, 2:], target[:, 2:])
+ enclose_wh = (enclose_x2y2 - enclose_x1y1).clamp(min=0)
+
+ cw = enclose_wh[:, 0]
+ ch = enclose_wh[:, 1]
+
+ c2 = cw**2 + ch**2 + eps
+
+ b1_x1, b1_y1 = pred[:, 0], pred[:, 1]
+ b1_x2, b1_y2 = pred[:, 2], pred[:, 3]
+ b2_x1, b2_y1 = target[:, 0], target[:, 1]
+ b2_x2, b2_y2 = target[:, 2], target[:, 3]
+
+ w1, h1 = b1_x2 - b1_x1, b1_y2 - b1_y1 + eps
+ w2, h2 = b2_x2 - b2_x1, b2_y2 - b2_y1 + eps
+
+ left = ((b2_x1 + b2_x2) - (b1_x1 + b1_x2))**2 / 4
+ right = ((b2_y1 + b2_y2) - (b1_y1 + b1_y2))**2 / 4
+ rho2 = left + right
+
+ factor = 4 / math.pi**2
+ v = factor * torch.pow(torch.atan(w2 / h2) - torch.atan(w1 / h1), 2)
+
+ # CIoU
+ cious = ious - (rho2 / c2 + v**2 / (1 - ious + v))
+ loss = 1 - cious
+ return loss
+
+
+@LOSSES.register_module()
+class IoULoss(nn.Module):
+ """IoULoss.
+
+ Computing the IoU loss between a set of predicted bboxes and target bboxes.
+
+ Args:
+ linear (bool): If True, use linear scale of loss instead of log scale.
+ Default: False.
+ eps (float): Eps to avoid log(0).
+ reduction (str): Options are "none", "mean" and "sum".
+ loss_weight (float): Weight of loss.
+ """
+
+ def __init__(self,
+ linear=False,
+ eps=1e-6,
+ reduction='mean',
+ loss_weight=1.0):
+ super(IoULoss, self).__init__()
+ self.linear = linear
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None. Options are "none", "mean" and "sum".
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if (weight is not None) and (not torch.any(weight > 0)) and (
+ reduction != 'none'):
+ return (pred * weight).sum() # 0
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # iou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * iou_loss(
+ pred,
+ target,
+ weight,
+ linear=self.linear,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class BoundedIoULoss(nn.Module):
+
+ def __init__(self, beta=0.2, eps=1e-3, reduction='mean', loss_weight=1.0):
+ super(BoundedIoULoss, self).__init__()
+ self.beta = beta
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss = self.loss_weight * bounded_iou_loss(
+ pred,
+ target,
+ weight,
+ beta=self.beta,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class GIoULoss(nn.Module):
+
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(GIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * giou_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class DIoULoss(nn.Module):
+
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(DIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * diou_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
+
+
+@LOSSES.register_module()
+class CIoULoss(nn.Module):
+
+ def __init__(self, eps=1e-6, reduction='mean', loss_weight=1.0):
+ super(CIoULoss, self).__init__()
+ self.eps = eps
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ if weight is not None and not torch.any(weight > 0):
+ return (pred * weight).sum() # 0
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if weight is not None and weight.dim() > 1:
+ # TODO: remove this in the future
+ # reduce the weight of shape (n, 4) to (n,) to match the
+ # giou_loss of shape (n,)
+ assert weight.shape == pred.shape
+ weight = weight.mean(-1)
+ loss = self.loss_weight * ciou_loss(
+ pred,
+ target,
+ weight,
+ eps=self.eps,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss
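
All five modules share the (x1, y1, x2, y2) box convention. A tiny comparison sketch with two hand-written boxes; it assumes `bbox_overlaps` from `mmdet.core` is importable, as in the file above.

```python
import torch

pred = torch.tensor([[10., 10., 50., 50.],
                     [20., 20., 40., 60.]])
target = torch.tensor([[15., 15., 55., 55.],
                       [20., 20., 40., 60.]])

# identical second box: its IoU is 1, so the losses come from box 1
for loss_cls in (IoULoss, GIoULoss, DIoULoss, CIoULoss):
    print(loss_cls.__name__, loss_cls()(pred, target).item())
```
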
diff --git a/mmdet/models/losses/kd_loss.py b/mmdet/models/losses/kd_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3abb68d4f7b3eec98b873f69c1105a22eb33913
--- /dev/null
+++ b/mmdet/models/losses/kd_loss.py
@@ -0,0 +1,87 @@
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def knowledge_distillation_kl_div_loss(pred,
+ soft_label,
+ T,
+ detach_target=True):
+ r"""Loss function for knowledge distilling using KL divergence.
+
+ Args:
+ pred (Tensor): Predicted logits with shape (N, n + 1).
+        soft_label (Tensor): Target logits with shape (N, n + 1).
+        T (int): Temperature for distillation.
+        detach_target (bool): Remove soft_label from automatic differentiation.
+            Defaults to True.
+
+ Returns:
+ torch.Tensor: Loss tensor with shape (N,).
+ """
+ assert pred.size() == soft_label.size()
+ target = F.softmax(soft_label / T, dim=1)
+ if detach_target:
+ target = target.detach()
+
+ kd_loss = F.kl_div(
+ F.log_softmax(pred / T, dim=1), target, reduction='none').mean(1) * (
+ T * T)
+
+ return kd_loss
+
+
+@LOSSES.register_module()
+class KnowledgeDistillationKLDivLoss(nn.Module):
+ """Loss function for knowledge distilling using KL divergence.
+
+ Args:
+ reduction (str): Options are `'none'`, `'mean'` and `'sum'`.
+ loss_weight (float): Loss weight of current loss.
+ T (int): Temperature for distillation.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0, T=10):
+ super(KnowledgeDistillationKLDivLoss, self).__init__()
+ assert T >= 1
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+ self.T = T
+
+ def forward(self,
+ pred,
+ soft_label,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (Tensor): Predicted logits with shape (N, n + 1).
+            soft_label (Tensor): Target logits with shape (N, n + 1).
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+
+ loss_kd = self.loss_weight * knowledge_distillation_kl_div_loss(
+ pred,
+ soft_label,
+ weight,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ T=self.T)
+
+ return loss_kd
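
A minimal distillation sketch with dummy logits; student and teacher must share a shape, and the teacher side is detached inside the loss by default.

```python
import torch

kd = KnowledgeDistillationKLDivLoss(T=10, loss_weight=0.25)
student_logits = torch.randn(8, 17)
teacher_logits = torch.randn(8, 17)
loss_kd = kd(student_logits, teacher_logits)
```
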
diff --git a/mmdet/models/losses/mse_loss.py b/mmdet/models/losses/mse_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..68d05752a245548862f4c9919448d4fb8dc1b8ca
--- /dev/null
+++ b/mmdet/models/losses/mse_loss.py
@@ -0,0 +1,49 @@
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@weighted_loss
+def mse_loss(pred, target):
+ """Warpper of mse loss."""
+ return F.mse_loss(pred, target, reduction='none')
+
+
+@LOSSES.register_module()
+class MSELoss(nn.Module):
+ """MSELoss.
+
+ Args:
+ reduction (str, optional): The method that reduces the loss to a
+ scalar. Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of the loss. Defaults to 1.0
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super().__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self, pred, target, weight=None, avg_factor=None):
+ """Forward function of loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): Weight of the loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ loss = self.loss_weight * mse_loss(
+ pred,
+ target,
+ weight,
+ reduction=self.reduction,
+ avg_factor=avg_factor)
+ return loss
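
For completeness, a one-liner sketch with arbitrary shapes:

```python
import torch

mse = MSELoss(loss_weight=0.5)
loss = mse(torch.randn(4, 2), torch.zeros(4, 2))  # 0.5 * mean squared error
```
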
diff --git a/mmdet/models/losses/pisa_loss.py b/mmdet/models/losses/pisa_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a48adfcd400bb07b719a6fbd5a8af0508820629
--- /dev/null
+++ b/mmdet/models/losses/pisa_loss.py
@@ -0,0 +1,183 @@
+import mmcv
+import torch
+
+from mmdet.core import bbox_overlaps
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def isr_p(cls_score,
+ bbox_pred,
+ bbox_targets,
+ rois,
+ sampling_results,
+ loss_cls,
+ bbox_coder,
+ k=2,
+ bias=0,
+ num_class=80):
+ """Importance-based Sample Reweighting (ISR_P), positive part.
+
+ Args:
+ cls_score (Tensor): Predicted classification scores.
+ bbox_pred (Tensor): Predicted bbox deltas.
+        bbox_targets (tuple[Tensor]): A tuple of bbox targets, which are
+            labels, label_weights, bbox_targets and bbox_weights, respectively.
+ rois (Tensor): Anchors (single_stage) in shape (n, 4) or RoIs
+ (two_stage) in shape (n, 5).
+ sampling_results (obj): Sampling results.
+ loss_cls (func): Classification loss func of the head.
+ bbox_coder (obj): BBox coder of the head.
+ k (float): Power of the non-linear mapping.
+ bias (float): Shift of the non-linear mapping.
+ num_class (int): Number of classes, default: 80.
+
+ Return:
+ tuple([Tensor]): labels, imp_based_label_weights, bbox_targets,
+ bbox_target_weights
+ """
+
+ labels, label_weights, bbox_targets, bbox_weights = bbox_targets
+ pos_label_inds = ((labels >= 0) &
+ (labels < num_class)).nonzero().reshape(-1)
+ pos_labels = labels[pos_label_inds]
+
+ # if no positive samples, return the original targets
+ num_pos = float(pos_label_inds.size(0))
+ if num_pos == 0:
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ # merge pos_assigned_gt_inds of per image to a single tensor
+ gts = list()
+ last_max_gt = 0
+ for i in range(len(sampling_results)):
+ gt_i = sampling_results[i].pos_assigned_gt_inds
+ gts.append(gt_i + last_max_gt)
+ if len(gt_i) != 0:
+ last_max_gt = gt_i.max() + 1
+ gts = torch.cat(gts)
+ assert len(gts) == num_pos
+
+ cls_score = cls_score.detach()
+ bbox_pred = bbox_pred.detach()
+
+ # For single stage detectors, rois here indicate anchors, in shape (N, 4)
+ # For two stage detectors, rois are in shape (N, 5)
+ if rois.size(-1) == 5:
+ pos_rois = rois[pos_label_inds][:, 1:]
+ else:
+ pos_rois = rois[pos_label_inds]
+
+ if bbox_pred.size(-1) > 4:
+ bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
+ pos_delta_pred = bbox_pred[pos_label_inds, pos_labels].view(-1, 4)
+ else:
+ pos_delta_pred = bbox_pred[pos_label_inds].view(-1, 4)
+
+ # compute iou of the predicted bbox and the corresponding GT
+ pos_delta_target = bbox_targets[pos_label_inds].view(-1, 4)
+ pos_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_pred)
+ target_bbox_pred = bbox_coder.decode(pos_rois, pos_delta_target)
+ ious = bbox_overlaps(pos_bbox_pred, target_bbox_pred, is_aligned=True)
+
+ pos_imp_weights = label_weights[pos_label_inds]
+ # Two steps to compute IoU-HLR. Samples are first sorted by IoU locally,
+ # then sorted again within the same-rank group
+ max_l_num = pos_labels.bincount().max()
+ for label in pos_labels.unique():
+ l_inds = (pos_labels == label).nonzero().view(-1)
+ l_gts = gts[l_inds]
+ for t in l_gts.unique():
+ t_inds = l_inds[l_gts == t]
+ t_ious = ious[t_inds]
+ _, t_iou_rank_idx = t_ious.sort(descending=True)
+ _, t_iou_rank = t_iou_rank_idx.sort()
+ ious[t_inds] += max_l_num - t_iou_rank.float()
+ l_ious = ious[l_inds]
+ _, l_iou_rank_idx = l_ious.sort(descending=True)
+ _, l_iou_rank = l_iou_rank_idx.sort() # IoU-HLR
+ # linearly map HLR to label weights
+ pos_imp_weights[l_inds] *= (max_l_num - l_iou_rank.float()) / max_l_num
+
+ pos_imp_weights = (bias + pos_imp_weights * (1 - bias)).pow(k)
+
+ # normalize to make the new weighted loss value equal to the original loss
+ pos_loss_cls = loss_cls(
+ cls_score[pos_label_inds], pos_labels, reduction_override='none')
+ if pos_loss_cls.dim() > 1:
+ ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds][:,
+ None]
+ new_pos_loss_cls = pos_loss_cls * pos_imp_weights[:, None]
+ else:
+ ori_pos_loss_cls = pos_loss_cls * label_weights[pos_label_inds]
+ new_pos_loss_cls = pos_loss_cls * pos_imp_weights
+ pos_loss_cls_ratio = ori_pos_loss_cls.sum() / new_pos_loss_cls.sum()
+ pos_imp_weights = pos_imp_weights * pos_loss_cls_ratio
+ label_weights[pos_label_inds] = pos_imp_weights
+
+ bbox_targets = labels, label_weights, bbox_targets, bbox_weights
+ return bbox_targets
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def carl_loss(cls_score,
+ labels,
+ bbox_pred,
+ bbox_targets,
+ loss_bbox,
+ k=1,
+ bias=0.2,
+ avg_factor=None,
+ sigmoid=False,
+ num_class=80):
+ """Classification-Aware Regression Loss (CARL).
+
+ Args:
+ cls_score (Tensor): Predicted classification scores.
+ labels (Tensor): Targets of classification.
+ bbox_pred (Tensor): Predicted bbox deltas.
+ bbox_targets (Tensor): Target of bbox regression.
+ loss_bbox (func): Regression loss func of the head.
+ k (float): Power of the non-linear mapping.
+ bias (float): Shift of the non-linear mapping.
+ avg_factor (int): Average factor used in regression loss.
+ sigmoid (bool): Activation of the classification score.
+ num_class (int): Number of classes, default: 80.
+
+ Return:
+ dict: CARL loss dict.
+ """
+ pos_label_inds = ((labels >= 0) &
+ (labels < num_class)).nonzero().reshape(-1)
+ if pos_label_inds.numel() == 0:
+ return dict(loss_carl=cls_score.sum()[None] * 0.)
+ pos_labels = labels[pos_label_inds]
+
+ # multiply pos_cls_score with the corresponding bbox weight
+ # and remain gradient
+ if sigmoid:
+ pos_cls_score = cls_score.sigmoid()[pos_label_inds, pos_labels]
+ else:
+ pos_cls_score = cls_score.softmax(-1)[pos_label_inds, pos_labels]
+ carl_loss_weights = (bias + (1 - bias) * pos_cls_score).pow(k)
+
+ # normalize carl_loss_weight to make its sum equal to num positive
+ num_pos = float(pos_cls_score.size(0))
+ weight_ratio = num_pos / carl_loss_weights.sum()
+ carl_loss_weights *= weight_ratio
+
+ if avg_factor is None:
+ avg_factor = bbox_targets.size(0)
+ # if is class agnostic, bbox pred is in shape (N, 4)
+ # otherwise, bbox pred is in shape (N, #classes, 4)
+ if bbox_pred.size(-1) > 4:
+ bbox_pred = bbox_pred.view(bbox_pred.size(0), -1, 4)
+ pos_bbox_preds = bbox_pred[pos_label_inds, pos_labels]
+ else:
+ pos_bbox_preds = bbox_pred[pos_label_inds]
+ ori_loss_reg = loss_bbox(
+ pos_bbox_preds,
+ bbox_targets[pos_label_inds],
+ reduction_override='none') / avg_factor
+ loss_carl = (ori_loss_reg * carl_loss_weights[:, None]).sum()
+ return dict(loss_carl=loss_carl[None])
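
`isr_p` needs live assigner/sampler state, but `carl_loss` can be smoke-tested standalone. A hedged sketch, assuming the `L1Loss` module added later in this diff is exported from `mmdet.models.losses`:

```python
import torch
from mmdet.models.losses import L1Loss  # assumed export from this diff

cls_score = torch.randn(12, 81)       # 80 classes + background column
labels = torch.randint(0, 81, (12,))
bbox_pred = torch.randn(12, 4)        # class-agnostic deltas
bbox_targets = torch.randn(12, 4)

out = carl_loss(cls_score, labels, bbox_pred, bbox_targets,
                L1Loss(), k=1, bias=0.2, num_class=80)
print(out['loss_carl'])
```
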
diff --git a/mmdet/models/losses/smooth_l1_loss.py b/mmdet/models/losses/smooth_l1_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec9c98a52d1932d6ccff18938c17c36755bf1baf
--- /dev/null
+++ b/mmdet/models/losses/smooth_l1_loss.py
@@ -0,0 +1,139 @@
+import mmcv
+import torch
+import torch.nn as nn
+
+from ..builder import LOSSES
+from .utils import weighted_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def smooth_l1_loss(pred, target, beta=1.0):
+ """Smooth L1 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ beta (float, optional): The threshold in the piecewise function.
+ Defaults to 1.0.
+
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ assert beta > 0
+ assert pred.size() == target.size() and target.numel() > 0
+ diff = torch.abs(pred - target)
+ loss = torch.where(diff < beta, 0.5 * diff * diff / beta,
+ diff - 0.5 * beta)
+ return loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+@weighted_loss
+def l1_loss(pred, target):
+ """L1 loss.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+
+ Returns:
+ torch.Tensor: Calculated loss
+ """
+ assert pred.size() == target.size() and target.numel() > 0
+ loss = torch.abs(pred - target)
+ return loss
+
+
+@LOSSES.register_module()
+class SmoothL1Loss(nn.Module):
+ """Smooth L1 loss.
+
+ Args:
+ beta (float, optional): The threshold in the piecewise function.
+ Defaults to 1.0.
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum". Defaults to "mean".
+ loss_weight (float, optional): The weight of loss.
+ """
+
+ def __init__(self, beta=1.0, reduction='mean', loss_weight=1.0):
+ super(SmoothL1Loss, self).__init__()
+ self.beta = beta
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None,
+ **kwargs):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_bbox = self.loss_weight * smooth_l1_loss(
+ pred,
+ target,
+ weight,
+ beta=self.beta,
+ reduction=reduction,
+ avg_factor=avg_factor,
+ **kwargs)
+ return loss_bbox
+
+
+@LOSSES.register_module()
+class L1Loss(nn.Module):
+ """L1 loss.
+
+ Args:
+ reduction (str, optional): The method to reduce the loss.
+ Options are "none", "mean" and "sum".
+ loss_weight (float, optional): The weight of loss.
+ """
+
+ def __init__(self, reduction='mean', loss_weight=1.0):
+ super(L1Loss, self).__init__()
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Defaults to None.
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ loss_bbox = self.loss_weight * l1_loss(
+ pred, target, weight, reduction=reduction, avg_factor=avg_factor)
+ return loss_bbox
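
Typical invocation with the per-element weights and `avg_factor` that detection heads pass in; the tensors and the beta value below are just examples.

```python
import torch

loss_fn = SmoothL1Loss(beta=1.0 / 9.0, loss_weight=1.0)
pred = torch.randn(6, 4)
target = torch.randn(6, 4)
weight = torch.ones(6, 4)   # zero out rows to ignore individual samples
loss = loss_fn(pred, target, weight, avg_factor=6)
```
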
diff --git a/mmdet/models/losses/utils.py b/mmdet/models/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4756d7fcefd7cda1294c2662b4ca3e90c0a8e124
--- /dev/null
+++ b/mmdet/models/losses/utils.py
@@ -0,0 +1,100 @@
+import functools
+
+import mmcv
+import torch.nn.functional as F
+
+
+def reduce_loss(loss, reduction):
+ """Reduce loss as specified.
+
+ Args:
+ loss (Tensor): Elementwise loss tensor.
+ reduction (str): Options are "none", "mean" and "sum".
+
+ Return:
+ Tensor: Reduced loss tensor.
+ """
+ reduction_enum = F._Reduction.get_enum(reduction)
+ # none: 0, elementwise_mean:1, sum: 2
+ if reduction_enum == 0:
+ return loss
+ elif reduction_enum == 1:
+ return loss.mean()
+ elif reduction_enum == 2:
+ return loss.sum()
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def weight_reduce_loss(loss, weight=None, reduction='mean', avg_factor=None):
+ """Apply element-wise weight and reduce loss.
+
+ Args:
+ loss (Tensor): Element-wise loss.
+ weight (Tensor): Element-wise weights.
+ reduction (str): Same as built-in losses of PyTorch.
+        avg_factor (float): Average factor when computing the mean of losses.
+
+ Returns:
+ Tensor: Processed loss values.
+ """
+ # if weight is specified, apply element-wise weight
+ if weight is not None:
+ loss = loss * weight
+
+ # if avg_factor is not specified, just reduce the loss
+ if avg_factor is None:
+ loss = reduce_loss(loss, reduction)
+ else:
+ # if reduction is mean, then average the loss by avg_factor
+ if reduction == 'mean':
+ loss = loss.sum() / avg_factor
+ # if reduction is 'none', then do nothing, otherwise raise an error
+ elif reduction != 'none':
+ raise ValueError('avg_factor can not be used with reduction="sum"')
+ return loss
+
+
+def weighted_loss(loss_func):
+ """Create a weighted version of a given loss function.
+
+ To use this decorator, the loss function must have the signature like
+ `loss_func(pred, target, **kwargs)`. The function only needs to compute
+ element-wise loss without any reduction. This decorator will add weight
+ and reduction arguments to the function. The decorated function will have
+ the signature like `loss_func(pred, target, weight=None, reduction='mean',
+ avg_factor=None, **kwargs)`.
+
+ :Example:
+
+ >>> import torch
+ >>> @weighted_loss
+ >>> def l1_loss(pred, target):
+ >>> return (pred - target).abs()
+
+ >>> pred = torch.Tensor([0, 2, 3])
+ >>> target = torch.Tensor([1, 1, 1])
+ >>> weight = torch.Tensor([1, 0, 1])
+
+ >>> l1_loss(pred, target)
+ tensor(1.3333)
+ >>> l1_loss(pred, target, weight)
+ tensor(1.)
+ >>> l1_loss(pred, target, reduction='none')
+ tensor([1., 1., 2.])
+ >>> l1_loss(pred, target, weight, avg_factor=2)
+ tensor(1.5000)
+ """
+
+ @functools.wraps(loss_func)
+ def wrapper(pred,
+ target,
+ weight=None,
+ reduction='mean',
+ avg_factor=None,
+ **kwargs):
+ # get element-wise loss
+ loss = loss_func(pred, target, **kwargs)
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+ return wrapper
diff --git a/mmdet/models/losses/varifocal_loss.py b/mmdet/models/losses/varifocal_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f00bd6916c04fef45a9aeecb50888266420daf9
--- /dev/null
+++ b/mmdet/models/losses/varifocal_loss.py
@@ -0,0 +1,133 @@
+import mmcv
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ..builder import LOSSES
+from .utils import weight_reduce_loss
+
+
+@mmcv.jit(derivate=True, coderize=True)
+def varifocal_loss(pred,
+ target,
+ weight=None,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ reduction='mean',
+ avg_factor=None):
+ """`Varifocal Loss `_
+
+ Args:
+ pred (torch.Tensor): The prediction with shape (N, C), C is the
+ number of classes
+ target (torch.Tensor): The learning target of the iou-aware
+ classification score with shape (N, C), C is the number of classes.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ alpha (float, optional): A balance factor for the negative part of
+ Varifocal Loss, which is different from the alpha of Focal Loss.
+ Defaults to 0.75.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ iou_weighted (bool, optional): Whether to weight the loss of the
+ positive example with the iou target. Defaults to True.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ """
+ # pred and target should be of the same size
+ assert pred.size() == target.size()
+ pred_sigmoid = pred.sigmoid()
+ target = target.type_as(pred)
+ if iou_weighted:
+ focal_weight = target * (target > 0.0).float() + \
+ alpha * (pred_sigmoid - target).abs().pow(gamma) * \
+ (target <= 0.0).float()
+ else:
+ focal_weight = (target > 0.0).float() + \
+ alpha * (pred_sigmoid - target).abs().pow(gamma) * \
+ (target <= 0.0).float()
+ loss = F.binary_cross_entropy_with_logits(
+ pred, target, reduction='none') * focal_weight
+ loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
+ return loss
+
+
+@LOSSES.register_module()
+class VarifocalLoss(nn.Module):
+
+ def __init__(self,
+ use_sigmoid=True,
+ alpha=0.75,
+ gamma=2.0,
+ iou_weighted=True,
+ reduction='mean',
+ loss_weight=1.0):
+ """`Varifocal Loss `_
+
+ Args:
+ use_sigmoid (bool, optional): Whether the prediction is
+ used for sigmoid or softmax. Defaults to True.
+ alpha (float, optional): A balance factor for the negative part of
+ Varifocal Loss, which is different from the alpha of Focal
+ Loss. Defaults to 0.75.
+ gamma (float, optional): The gamma for calculating the modulating
+ factor. Defaults to 2.0.
+ iou_weighted (bool, optional): Whether to weight the loss of the
+ positive examples with the iou target. Defaults to True.
+ reduction (str, optional): The method used to reduce the loss into
+ a scalar. Defaults to 'mean'. Options are "none", "mean" and
+ "sum".
+ loss_weight (float, optional): Weight of loss. Defaults to 1.0.
+ """
+ super(VarifocalLoss, self).__init__()
+ assert use_sigmoid is True, \
+ 'Only sigmoid varifocal loss supported now.'
+ assert alpha >= 0.0
+ self.use_sigmoid = use_sigmoid
+ self.alpha = alpha
+ self.gamma = gamma
+ self.iou_weighted = iou_weighted
+ self.reduction = reduction
+ self.loss_weight = loss_weight
+
+ def forward(self,
+ pred,
+ target,
+ weight=None,
+ avg_factor=None,
+ reduction_override=None):
+ """Forward function.
+
+ Args:
+ pred (torch.Tensor): The prediction.
+ target (torch.Tensor): The learning target of the prediction.
+ weight (torch.Tensor, optional): The weight of loss for each
+ prediction. Defaults to None.
+ avg_factor (int, optional): Average factor that is used to average
+ the loss. Defaults to None.
+ reduction_override (str, optional): The reduction method used to
+ override the original reduction method of the loss.
+ Options are "none", "mean" and "sum".
+
+ Returns:
+ torch.Tensor: The calculated loss
+ """
+ assert reduction_override in (None, 'none', 'mean', 'sum')
+ reduction = (
+ reduction_override if reduction_override else self.reduction)
+ if self.use_sigmoid:
+ loss_cls = self.loss_weight * varifocal_loss(
+ pred,
+ target,
+ weight,
+ alpha=self.alpha,
+ gamma=self.gamma,
+ iou_weighted=self.iou_weighted,
+ reduction=reduction,
+ avg_factor=avg_factor)
+ else:
+ raise NotImplementedError
+ return loss_cls
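
A hedged sketch: unlike plain focal loss, the target here is a dense (N, C) map of IoU-aware scores, zero for negatives. Values are dummies.

```python
import torch

vfl = VarifocalLoss(alpha=0.75, gamma=2.0, iou_weighted=True)
pred = torch.randn(16, 80)   # raw logits; sigmoid is applied inside
target = torch.zeros(16, 80)
target[0, 5] = 0.8           # IoU between the predicted box and its gt
loss = vfl(pred, target)
```
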
diff --git a/mmdet/models/necks/__init__.py b/mmdet/models/necks/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02f833a8a0f538a8c06fef622d1cadc1a1b66ea2
--- /dev/null
+++ b/mmdet/models/necks/__init__.py
@@ -0,0 +1,16 @@
+from .bfp import BFP
+from .channel_mapper import ChannelMapper
+from .fpg import FPG
+from .fpn import FPN
+from .fpn_carafe import FPN_CARAFE
+from .hrfpn import HRFPN
+from .nas_fpn import NASFPN
+from .nasfcos_fpn import NASFCOS_FPN
+from .pafpn import PAFPN
+from .rfp import RFP
+from .yolo_neck import YOLOV3Neck
+
+__all__ = [
+ 'FPN', 'BFP', 'ChannelMapper', 'HRFPN', 'NASFPN', 'FPN_CARAFE', 'PAFPN',
+ 'NASFCOS_FPN', 'RFP', 'YOLOV3Neck', 'FPG'
+]
diff --git a/mmdet/models/necks/bfp.py b/mmdet/models/necks/bfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..123f5515ab6b51867d5781aa1572a0810670235f
--- /dev/null
+++ b/mmdet/models/necks/bfp.py
@@ -0,0 +1,104 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+from mmcv.cnn.bricks import NonLocal2d
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class BFP(nn.Module):
+ """BFP (Balanced Feature Pyramids)
+
+ BFP takes multi-level features as inputs and gather them into a single one,
+ then refine the gathered feature and scatter the refined results to
+ multi-level features. This module is used in Libra R-CNN (CVPR 2019), see
+ the paper `Libra R-CNN: Towards Balanced Learning for Object Detection
+    <https://arxiv.org/abs/1904.02701>`_ for details.
+
+ Args:
+ in_channels (int): Number of input channels (feature maps of all levels
+ should have the same channels).
+ num_levels (int): Number of input feature levels.
+ conv_cfg (dict): The config dict for convolution layers.
+ norm_cfg (dict): The config dict for normalization layers.
+ refine_level (int): Index of integration and refine level of BSF in
+ multi-level features from bottom to top.
+ refine_type (str): Type of the refine op, currently support
+ [None, 'conv', 'non_local'].
+ """
+
+ def __init__(self,
+ in_channels,
+ num_levels,
+ refine_level=2,
+ refine_type=None,
+ conv_cfg=None,
+ norm_cfg=None):
+ super(BFP, self).__init__()
+ assert refine_type in [None, 'conv', 'non_local']
+
+ self.in_channels = in_channels
+ self.num_levels = num_levels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.refine_level = refine_level
+ self.refine_type = refine_type
+ assert 0 <= self.refine_level < self.num_levels
+
+ if self.refine_type == 'conv':
+ self.refine = ConvModule(
+ self.in_channels,
+ self.in_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ elif self.refine_type == 'non_local':
+ self.refine = NonLocal2d(
+ self.in_channels,
+ reduction=1,
+ use_scale=False,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+
+ def init_weights(self):
+ """Initialize the weights of FPN module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == self.num_levels
+
+ # step 1: gather multi-level features by resize and average
+ feats = []
+ gather_size = inputs[self.refine_level].size()[2:]
+ for i in range(self.num_levels):
+ if i < self.refine_level:
+ gathered = F.adaptive_max_pool2d(
+ inputs[i], output_size=gather_size)
+ else:
+ gathered = F.interpolate(
+ inputs[i], size=gather_size, mode='nearest')
+ feats.append(gathered)
+
+ bsf = sum(feats) / len(feats)
+
+ # step 2: refine gathered features
+ if self.refine_type is not None:
+ bsf = self.refine(bsf)
+
+ # step 3: scatter refined features to multi-levels by a residual path
+ outs = []
+ for i in range(self.num_levels):
+ out_size = inputs[i].size()[2:]
+ if i < self.refine_level:
+ residual = F.interpolate(bsf, size=out_size, mode='nearest')
+ else:
+ residual = F.adaptive_max_pool2d(bsf, output_size=out_size)
+ outs.append(residual + inputs[i])
+
+ return tuple(outs)
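
BFP is shape-preserving by construction; a sketch with a 5-level FPN-like input (channel count and spatial sizes are arbitrary):

```python
import torch

bfp = BFP(in_channels=256, num_levels=5, refine_level=2, refine_type='conv')
bfp.init_weights()
inputs = [torch.rand(1, 256, s, s) for s in (64, 32, 16, 8, 4)]
outs = bfp(inputs)
assert all(o.shape == x.shape for o, x in zip(outs, inputs))
```
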
diff --git a/mmdet/models/necks/channel_mapper.py b/mmdet/models/necks/channel_mapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4f5ed44caefb1612df67785b1f4f0d9ec46ee93
--- /dev/null
+++ b/mmdet/models/necks/channel_mapper.py
@@ -0,0 +1,74 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, xavier_init
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class ChannelMapper(nn.Module):
+ r"""Channel Mapper to reduce/increase channels of backbone features.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale).
+ kernel_size (int, optional): kernel_size for reducing channels (used
+ at each scale). Default: 3.
+ conv_cfg (dict, optional): Config dict for convolution layer.
+ Default: None.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: None.
+ act_cfg (dict, optional): Config dict for activation layer in
+ ConvModule. Default: dict(type='ReLU').
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = ChannelMapper(in_channels, 11, 3).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ kernel_size=3,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU')):
+ super(ChannelMapper, self).__init__()
+ assert isinstance(in_channels, list)
+
+ self.convs = nn.ModuleList()
+ for in_channel in in_channels:
+ self.convs.append(
+ ConvModule(
+ in_channel,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg))
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ """Initialize the weights of ChannelMapper module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.convs)
+ outs = [self.convs[i](inputs[i]) for i in range(len(inputs))]
+ return tuple(outs)
diff --git a/mmdet/models/necks/fpg.py b/mmdet/models/necks/fpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8e0d163ccf8cef6211530ba6c1b4d558ff6403f
--- /dev/null
+++ b/mmdet/models/necks/fpg.py
@@ -0,0 +1,398 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, caffe2_xavier_init, constant_init, is_norm
+
+from ..builder import NECKS
+
+
+class Transition(nn.Module):
+ """Base class for transition.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ """
+
+ def __init__(self, in_channels, out_channels):
+ super().__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+    def forward(self, x):
+        pass
+
+
+class UpInterpolationConv(Transition):
+ """A transition used for up-sampling.
+
+ Up-sample the input by interpolation then refines the feature by
+ a convolution layer.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ scale_factor (int): Up-sampling factor. Default: 2.
+        mode (str): Interpolation mode. Default: 'nearest'.
+ align_corners (bool): Whether align corners when interpolation.
+ Default: None.
+ kernel_size (int): Kernel size for the conv. Default: 3.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ scale_factor=2,
+ mode='nearest',
+ align_corners=None,
+ kernel_size=3,
+ **kwargs):
+ super().__init__(in_channels, out_channels)
+ self.mode = mode
+ self.scale_factor = scale_factor
+ self.align_corners = align_corners
+ self.conv = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ **kwargs)
+
+ def forward(self, x):
+ x = F.interpolate(
+ x,
+ scale_factor=self.scale_factor,
+ mode=self.mode,
+ align_corners=self.align_corners)
+ x = self.conv(x)
+ return x
+
+
+class LastConv(Transition):
+ """A transition used for refining the output of the last stage.
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of output channels.
+ num_inputs (int): Number of inputs of the FPN features.
+ kernel_size (int): Kernel size for the conv. Default: 3.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_inputs,
+ kernel_size=3,
+ **kwargs):
+ super().__init__(in_channels, out_channels)
+ self.num_inputs = num_inputs
+ self.conv_out = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size,
+ padding=(kernel_size - 1) // 2,
+ **kwargs)
+
+ def forward(self, inputs):
+ assert len(inputs) == self.num_inputs
+ return self.conv_out(inputs[-1])
+
+
+@NECKS.register_module()
+class FPG(nn.Module):
+ """FPG.
+
+ Implementation of `Feature Pyramid Grids (FPG)
+    <https://arxiv.org/abs/2004.03580>`_.
+    This implementation only gives the basic structure stated in the paper.
+    Users can implement different types of transitions to fully explore the
+    potential power of the structure of FPG.
+
+ Args:
+        in_channels (list[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ stack_times (int): The number of times the pyramid architecture will
+ be stacked.
+ paths (list[str]): Specify the path order of each stack level.
+ Each element in the list should be either 'bu' (bottom-up) or
+ 'td' (top-down).
+ inter_channels (int): Number of inter channels.
+        same_down_trans (dict): Transition that goes down at the same stage.
+        same_up_trans (dict): Transition that goes up at the same stage.
+        across_lateral_trans (dict): Across-pathway same-stage connection.
+        across_down_trans (dict): Across-pathway bottom-up connection.
+        across_up_trans (dict): Across-pathway top-down connection.
+        across_skip_trans (dict): Across-pathway skip connection.
+        output_trans (dict): Transition that transforms the output of the
+            last stage.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool): Whether to add conv layers on top of the
+            original feature maps. Default: False.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        skip_inds (list[tuple[int]]): The indices of the stacked stages that
+            are skipped for each output level. Default: None.
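+
+    Example:
+        >>> # A toy-sized sketch. The default across_down_trans is a plain
+        >>> # conv that cannot up-sample, so an interpolation conv is used
+        >>> # here; across_skip_trans is disabled because its default
+        >>> # 'identity' type is not registered in ``transition_types``.
+        >>> import torch
+        >>> in_channels = [8, 16, 32]
+        >>> inputs = [torch.rand(1, c, 32 // 2**i, 32 // 2**i)
+        ...           for i, c in enumerate(in_channels)]
+        >>> self = FPG(in_channels, 8, num_outs=5, stack_times=2,
+        ...            paths=['bu', 'td'],
+        ...            across_down_trans=dict(type='interpolation_conv',
+        ...                                   mode='nearest', kernel_size=3),
+        ...            across_skip_trans=None,
+        ...            skip_inds=[()] * 5).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 8, 32, 32])
+        torch.Size([1, 8, 16, 16])
+        torch.Size([1, 8, 8, 8])
+        torch.Size([1, 8, 4, 4])
+        torch.Size([1, 8, 2, 2])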
+ """
+
+ transition_types = {
+ 'conv': ConvModule,
+ 'interpolation_conv': UpInterpolationConv,
+ 'last_conv': LastConv,
+ }
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ stack_times,
+ paths,
+ inter_channels=None,
+ same_down_trans=None,
+ same_up_trans=dict(
+ type='conv', kernel_size=3, stride=2, padding=1),
+ across_lateral_trans=dict(type='conv', kernel_size=1),
+ across_down_trans=dict(type='conv', kernel_size=3),
+ across_up_trans=None,
+ across_skip_trans=dict(type='identity'),
+ output_trans=dict(type='last_conv', kernel_size=3),
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ norm_cfg=None,
+ skip_inds=None):
+ super(FPG, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ if inter_channels is None:
+ self.inter_channels = [out_channels for _ in range(num_outs)]
+ elif isinstance(inter_channels, int):
+ self.inter_channels = [inter_channels for _ in range(num_outs)]
+ else:
+ assert isinstance(inter_channels, list)
+ assert len(inter_channels) == num_outs
+ self.inter_channels = inter_channels
+ self.stack_times = stack_times
+ self.paths = paths
+ assert isinstance(paths, list) and len(paths) == stack_times
+ for d in paths:
+ assert d in ('bu', 'td')
+
+ self.same_down_trans = same_down_trans
+ self.same_up_trans = same_up_trans
+ self.across_lateral_trans = across_lateral_trans
+ self.across_down_trans = across_down_trans
+ self.across_up_trans = across_up_trans
+ self.output_trans = output_trans
+ self.across_skip_trans = across_skip_trans
+
+ self.with_bias = norm_cfg is None
+        # skip_inds must be specified if across_skip_trans is not None
+        if self.across_skip_trans is not None:
+            assert skip_inds is not None
+ self.skip_inds = skip_inds
+ assert len(self.skip_inds[0]) <= self.stack_times
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+
+ # build lateral 1x1 convs to reduce channels
+ self.lateral_convs = nn.ModuleList()
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = nn.Conv2d(self.in_channels[i],
+ self.inter_channels[i - self.start_level], 1)
+ self.lateral_convs.append(l_conv)
+
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ self.extra_downsamples = nn.ModuleList()
+ for i in range(extra_levels):
+ if self.add_extra_convs:
+ fpn_idx = self.backbone_end_level - self.start_level + i
+ extra_conv = nn.Conv2d(
+ self.inter_channels[fpn_idx - 1],
+ self.inter_channels[fpn_idx],
+ 3,
+ stride=2,
+ padding=1)
+ self.extra_downsamples.append(extra_conv)
+ else:
+ self.extra_downsamples.append(nn.MaxPool2d(1, stride=2))
+
+ self.fpn_transitions = nn.ModuleList() # stack times
+ for s in range(self.stack_times):
+ stage_trans = nn.ModuleList() # num of feature levels
+ for i in range(self.num_outs):
+ # same, across_lateral, across_down, across_up
+ trans = nn.ModuleDict()
+ if s in self.skip_inds[i]:
+ stage_trans.append(trans)
+ continue
+                # build same-stage up trans (used in bottom-up paths)
+ if i == 0 or self.same_up_trans is None:
+ same_up_trans = None
+ else:
+ same_up_trans = self.build_trans(
+ self.same_up_trans, self.inter_channels[i - 1],
+ self.inter_channels[i])
+ trans['same_up'] = same_up_trans
+                # build same-stage down trans (used in top-down paths)
+ if i == self.num_outs - 1 or self.same_down_trans is None:
+ same_down_trans = None
+ else:
+ same_down_trans = self.build_trans(
+ self.same_down_trans, self.inter_channels[i + 1],
+ self.inter_channels[i])
+ trans['same_down'] = same_down_trans
+ # build across lateral trans
+ across_lateral_trans = self.build_trans(
+ self.across_lateral_trans, self.inter_channels[i],
+ self.inter_channels[i])
+ trans['across_lateral'] = across_lateral_trans
+ # build across down trans
+ if i == self.num_outs - 1 or self.across_down_trans is None:
+ across_down_trans = None
+ else:
+ across_down_trans = self.build_trans(
+ self.across_down_trans, self.inter_channels[i + 1],
+ self.inter_channels[i])
+ trans['across_down'] = across_down_trans
+ # build across up trans
+ if i == 0 or self.across_up_trans is None:
+ across_up_trans = None
+ else:
+ across_up_trans = self.build_trans(
+ self.across_up_trans, self.inter_channels[i - 1],
+ self.inter_channels[i])
+ trans['across_up'] = across_up_trans
+                # build across-skip trans
+                if self.across_skip_trans is None:
+                    across_skip_trans = None
+                else:
+                    across_skip_trans = self.build_trans(
+                        self.across_skip_trans, self.inter_channels[i - 1],
+                        self.inter_channels[i])
+                trans['across_skip'] = across_skip_trans
+ stage_trans.append(trans)
+ self.fpn_transitions.append(stage_trans)
+
+ self.output_transition = nn.ModuleList() # output levels
+ for i in range(self.num_outs):
+ trans = self.build_trans(
+ self.output_trans,
+ self.inter_channels[i],
+ self.out_channels,
+ num_inputs=self.stack_times + 1)
+ self.output_transition.append(trans)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def build_trans(self, cfg, in_channels, out_channels, **extra_args):
+ cfg_ = cfg.copy()
+ trans_type = cfg_.pop('type')
+ trans_cls = self.transition_types[trans_type]
+ return trans_cls(in_channels, out_channels, **cfg_, **extra_args)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ caffe2_xavier_init(m)
+ elif is_norm(m):
+ constant_init(m, 1.0)
+
+ def fuse(self, fuse_dict):
+ out = None
+ for item in fuse_dict.values():
+ if item is not None:
+ if out is None:
+ out = item
+ else:
+ out = out + item
+ return out
+
+ def forward(self, inputs):
+ assert len(inputs) == len(self.in_channels)
+
+ # build all levels from original feature maps
+ feats = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ for downsample in self.extra_downsamples:
+ feats.append(downsample(feats[-1]))
+
+ outs = [feats]
+
+ for i in range(self.stack_times):
+ current_outs = outs[-1]
+ next_outs = []
+ direction = self.paths[i]
+ for j in range(self.num_outs):
+ if i in self.skip_inds[j]:
+ next_outs.append(outs[-1][j])
+ continue
+ # feature level
+ if direction == 'td':
+ lvl = self.num_outs - j - 1
+ else:
+ lvl = j
+ # get transitions
+ if direction == 'td':
+ same_trans = self.fpn_transitions[i][lvl]['same_down']
+ else:
+ same_trans = self.fpn_transitions[i][lvl]['same_up']
+ across_lateral_trans = self.fpn_transitions[i][lvl][
+ 'across_lateral']
+ across_down_trans = self.fpn_transitions[i][lvl]['across_down']
+ across_up_trans = self.fpn_transitions[i][lvl]['across_up']
+ across_skip_trans = self.fpn_transitions[i][lvl]['across_skip']
+ # init output
+ to_fuse = dict(
+ same=None, lateral=None, across_up=None, across_down=None)
+ # same downsample/upsample
+ if same_trans is not None:
+ to_fuse['same'] = same_trans(next_outs[-1])
+ # across lateral
+ if across_lateral_trans is not None:
+ to_fuse['lateral'] = across_lateral_trans(
+ current_outs[lvl])
+ # across downsample
+ if lvl > 0 and across_up_trans is not None:
+                    to_fuse['across_up'] = across_up_trans(
+                        current_outs[lvl - 1])
+ # across upsample
+ if (lvl < self.num_outs - 1 and across_down_trans is not None):
+ to_fuse['across_down'] = across_down_trans(
+ current_outs[lvl + 1])
+ if across_skip_trans is not None:
+ to_fuse['across_skip'] = across_skip_trans(outs[0][lvl])
+ x = self.fuse(to_fuse)
+ next_outs.append(x)
+
+ if direction == 'td':
+ outs.append(next_outs[::-1])
+ else:
+ outs.append(next_outs)
+
+ # output trans
+ final_outs = []
+ for i in range(self.num_outs):
+ lvl_out_list = []
+ for s in range(len(outs)):
+ lvl_out_list.append(outs[s][i])
+ lvl_out = self.output_transition[i](lvl_out_list)
+ final_outs.append(lvl_out)
+
+ return final_outs
diff --git a/mmdet/models/necks/fpn.py b/mmdet/models/necks/fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e5dfe685964f06e7a66b63a13e66162e63fcafd
--- /dev/null
+++ b/mmdet/models/necks/fpn.py
@@ -0,0 +1,221 @@
+import warnings
+
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, xavier_init
+from mmcv.runner import auto_fp16
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN(nn.Module):
+ r"""Feature Pyramid Network.
+
+ This is an implementation of paper `Feature Pyramid Networks for Object
+    Detection <https://arxiv.org/abs/1612.03144>`_.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool | str): If bool, it decides whether to add conv
+ layers on top of the original feature maps. Default to False.
+ If True, its actual mode is specified by `extra_convs_on_inputs`.
+ If str, it specifies the source feature map of the extra convs.
+ Only the following options are allowed
+
+ - 'on_input': Last feat map of neck inputs (i.e. backbone feature).
+ - 'on_lateral': Last feature map after lateral convs.
+ - 'on_output': The last output feature map after fpn convs.
+ extra_convs_on_inputs (bool, deprecated): Whether to apply extra convs
+ on the original feature from the backbone. If True,
+ it is equivalent to `add_extra_convs='on_input'`. If False, it is
+ equivalent to set `add_extra_convs='on_output'`. Default to True.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (str): Config dict for activation layer in ConvModule.
+ Default: None.
+ upsample_cfg (dict): Config dict for interpolate layer.
+ Default: `dict(mode='nearest')`
+
+ Example:
+ >>> import torch
+ >>> in_channels = [2, 3, 5, 7]
+ >>> scales = [340, 170, 84, 43]
+ >>> inputs = [torch.rand(1, c, s, s)
+ ... for c, s in zip(in_channels, scales)]
+ >>> self = FPN(in_channels, 11, len(in_channels)).eval()
+ >>> outputs = self.forward(inputs)
+ >>> for i in range(len(outputs)):
+ ... print(f'outputs[{i}].shape = {outputs[i].shape}')
+ outputs[0].shape = torch.Size([1, 11, 340, 340])
+ outputs[1].shape = torch.Size([1, 11, 170, 170])
+ outputs[2].shape = torch.Size([1, 11, 84, 84])
+ outputs[3].shape = torch.Size([1, 11, 43, 43])
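+
+        When ``num_outs`` is larger than the number of backbone levels and
+        ``add_extra_convs`` is False, the extra top levels come from
+        stride-2 max pooling of the last output (continuing the toy sizes
+        above):
+
+        >>> self = FPN(in_channels, 11, num_outs=6).eval()
+        >>> outputs = self.forward(inputs)
+        >>> print(outputs[-1].shape)
+        torch.Size([1, 11, 11, 11])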
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=True,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None,
+ upsample_cfg=dict(mode='nearest')):
+ super(FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.relu_before_extra_convs = relu_before_extra_convs
+ self.no_norm_on_lateral = no_norm_on_lateral
+ self.fp16_enabled = False
+ self.upsample_cfg = upsample_cfg.copy()
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+ assert isinstance(add_extra_convs, (str, bool))
+ if isinstance(add_extra_convs, str):
+ # Extra_convs_source choices: 'on_input', 'on_lateral', 'on_output'
+ assert add_extra_convs in ('on_input', 'on_lateral', 'on_output')
+ elif add_extra_convs: # True
+ if extra_convs_on_inputs:
+ # TODO: deprecate `extra_convs_on_inputs`
+ warnings.simplefilter('once')
+ warnings.warn(
+ '"extra_convs_on_inputs" will be deprecated in v2.9.0,'
+ 'Please use "add_extra_convs"', DeprecationWarning)
+ self.add_extra_convs = 'on_input'
+ else:
+ self.add_extra_convs = 'on_output'
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg if not self.no_norm_on_lateral else None,
+ act_cfg=act_cfg,
+ inplace=False)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ if self.add_extra_convs and extra_levels >= 1:
+ for i in range(extra_levels):
+ if i == 0 and self.add_extra_convs == 'on_input':
+ in_channels = self.in_channels[self.backbone_end_level - 1]
+ else:
+ in_channels = out_channels
+ extra_fpn_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.fpn_convs.append(extra_fpn_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ """Initialize the weights of FPN module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+
+ @auto_fp16()
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ # In some cases, fixing `scale factor` (e.g. 2) is preferred, but
+ # it cannot co-exist with `size` in `F.interpolate`.
+ if 'scale_factor' in self.upsample_cfg:
+ laterals[i - 1] += F.interpolate(laterals[i],
+ **self.upsample_cfg)
+ else:
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, **self.upsample_cfg)
+
+ # build outputs
+ # part 1: from original levels
+ outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+ # part 2: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ extra_source = inputs[self.backbone_end_level - 1]
+ elif self.add_extra_convs == 'on_lateral':
+ extra_source = laterals[-1]
+ elif self.add_extra_convs == 'on_output':
+ extra_source = outs[-1]
+ else:
+ raise NotImplementedError
+ outs.append(self.fpn_convs[used_backbone_levels](extra_source))
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/mmdet/models/necks/fpn_carafe.py b/mmdet/models/necks/fpn_carafe.py
new file mode 100644
index 0000000000000000000000000000000000000000..302e6576df9914e49166539108d6048b78c1fe71
--- /dev/null
+++ b/mmdet/models/necks/fpn_carafe.py
@@ -0,0 +1,267 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_upsample_layer, xavier_init
+from mmcv.ops.carafe import CARAFEPack
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class FPN_CARAFE(nn.Module):
+ """FPN_CARAFE is a more flexible implementation of FPN. It allows more
+ choice for upsample methods during the top-down pathway.
+
+ It can reproduce the performance of ICCV 2019 paper
+ CARAFE: Content-Aware ReAssembly of FEatures
+ Please refer to https://arxiv.org/abs/1905.02188 for more details.
+
+ Args:
+ in_channels (list[int]): Number of channels for each input feature map.
+ out_channels (int): Output channels of feature pyramids.
+ num_outs (int): Number of output stages.
+ start_level (int): Start level of feature pyramids.
+ (Default: 0)
+ end_level (int): End level of feature pyramids.
+ (Default: -1 indicates the last level).
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ activate (str): Type of activation function in ConvModule
+ (Default: None indicates w/o activation).
+ order (dict): Order of components in ConvModule.
+ upsample (str): Type of upsample layer.
+ upsample_cfg (dict): Dictionary to construct and config upsample layer.
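+
+    Example:
+        >>> # A toy-sized sketch; 'nearest' upsampling is configured instead
+        >>> # of the default 'carafe' so the example also runs without the
+        >>> # compiled CARAFE CUDA op.
+        >>> import torch
+        >>> in_channels = [8, 16, 32]
+        >>> inputs = [torch.rand(1, c, 32 // 2**i, 32 // 2**i)
+        ...           for i, c in enumerate(in_channels)]
+        >>> self = FPN_CARAFE(in_channels, 8, num_outs=5,
+        ...                   upsample_cfg=dict(type='nearest')).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 8, 32, 32])
+        torch.Size([1, 8, 16, 16])
+        torch.Size([1, 8, 8, 8])
+        torch.Size([1, 8, 4, 4])
+        torch.Size([1, 8, 2, 2])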
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ norm_cfg=None,
+ act_cfg=None,
+ order=('conv', 'norm', 'act'),
+ upsample_cfg=dict(
+ type='carafe',
+ up_kernel=5,
+ up_group=1,
+ encoder_kernel=3,
+ encoder_dilation=1)):
+ super(FPN_CARAFE, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.with_bias = norm_cfg is None
+ self.upsample_cfg = upsample_cfg.copy()
+ self.upsample = self.upsample_cfg.get('type')
+ self.relu = nn.ReLU(inplace=False)
+
+ self.order = order
+ assert order in [('conv', 'norm', 'act'), ('act', 'conv', 'norm')]
+
+ assert self.upsample in [
+ 'nearest', 'bilinear', 'deconv', 'pixel_shuffle', 'carafe', None
+ ]
+ if self.upsample in ['deconv', 'pixel_shuffle']:
+ assert hasattr(
+ self.upsample_cfg,
+ 'upsample_kernel') and self.upsample_cfg.upsample_kernel > 0
+ self.upsample_kernel = self.upsample_cfg.pop('upsample_kernel')
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+
+ self.lateral_convs = nn.ModuleList()
+ self.fpn_convs = nn.ModuleList()
+ self.upsample_modules = nn.ModuleList()
+
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ if i != self.backbone_end_level - 1:
+ upsample_cfg_ = self.upsample_cfg.copy()
+ if self.upsample == 'deconv':
+ upsample_cfg_.update(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=self.upsample_kernel,
+ stride=2,
+ padding=(self.upsample_kernel - 1) // 2,
+ output_padding=(self.upsample_kernel - 1) // 2)
+ elif self.upsample == 'pixel_shuffle':
+ upsample_cfg_.update(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ scale_factor=2,
+ upsample_kernel=self.upsample_kernel)
+ elif self.upsample == 'carafe':
+ upsample_cfg_.update(channels=out_channels, scale_factor=2)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample == 'nearest' else False)
+ upsample_cfg_.update(
+ scale_factor=2,
+ mode=self.upsample,
+ align_corners=align_corners)
+ upsample_module = build_upsample_layer(upsample_cfg_)
+ self.upsample_modules.append(upsample_module)
+ self.lateral_convs.append(l_conv)
+ self.fpn_convs.append(fpn_conv)
+
+ # add extra conv layers (e.g., RetinaNet)
+ extra_out_levels = (
+ num_outs - self.backbone_end_level + self.start_level)
+ if extra_out_levels >= 1:
+ for i in range(extra_out_levels):
+ in_channels = (
+ self.in_channels[self.backbone_end_level -
+ 1] if i == 0 else out_channels)
+ extra_l_conv = ConvModule(
+ in_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ norm_cfg=norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ if self.upsample == 'deconv':
+ upsampler_cfg_ = dict(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ kernel_size=self.upsample_kernel,
+ stride=2,
+ padding=(self.upsample_kernel - 1) // 2,
+ output_padding=(self.upsample_kernel - 1) // 2)
+ elif self.upsample == 'pixel_shuffle':
+ upsampler_cfg_ = dict(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ scale_factor=2,
+ upsample_kernel=self.upsample_kernel)
+ elif self.upsample == 'carafe':
+ upsampler_cfg_ = dict(
+ channels=out_channels,
+ scale_factor=2,
+ **self.upsample_cfg)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample == 'nearest' else False)
+ upsampler_cfg_ = dict(
+ scale_factor=2,
+ mode=self.upsample,
+ align_corners=align_corners)
+ upsampler_cfg_['type'] = self.upsample
+ upsample_module = build_upsample_layer(upsampler_cfg_)
+ extra_fpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ bias=self.with_bias,
+ act_cfg=act_cfg,
+ inplace=False,
+ order=self.order)
+ self.upsample_modules.append(upsample_module)
+ self.fpn_convs.append(extra_fpn_conv)
+ self.lateral_convs.append(extra_l_conv)
+
+ # default init_weights for conv(msra) and norm in ConvModule
+ def init_weights(self):
+ """Initialize the weights of module."""
+ for m in self.modules():
+ if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
+ xavier_init(m, distribution='uniform')
+ for m in self.modules():
+ if isinstance(m, CARAFEPack):
+ m.init_weights()
+
+ def slice_as(self, src, dst):
+ """Slice ``src`` as ``dst``
+
+ Note:
+ ``src`` should have the same or larger size than ``dst``.
+
+ Args:
+ src (torch.Tensor): Tensors to be sliced.
+ dst (torch.Tensor): ``src`` will be sliced to have the same
+ size as ``dst``.
+
+ Returns:
+ torch.Tensor: Sliced tensor.
+ """
+ assert (src.size(2) >= dst.size(2)) and (src.size(3) >= dst.size(3))
+ if src.size(2) == dst.size(2) and src.size(3) == dst.size(3):
+ return src
+ else:
+ return src[:, :, :dst.size(2), :dst.size(3)]
+
+ def tensor_add(self, a, b):
+ """Add tensors ``a`` and ``b`` that might have different sizes."""
+ if a.size() == b.size():
+ c = a + b
+ else:
+ c = a + self.slice_as(b, a)
+ return c
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = []
+ for i, lateral_conv in enumerate(self.lateral_convs):
+ if i <= self.backbone_end_level - self.start_level:
+ input = inputs[min(i + self.start_level, len(inputs) - 1)]
+ else:
+ input = laterals[-1]
+ lateral = lateral_conv(input)
+ laterals.append(lateral)
+
+ # build top-down path
+ for i in range(len(laterals) - 1, 0, -1):
+ if self.upsample is not None:
+ upsample_feat = self.upsample_modules[i - 1](laterals[i])
+ else:
+ upsample_feat = laterals[i]
+ laterals[i - 1] = self.tensor_add(laterals[i - 1], upsample_feat)
+
+ # build outputs
+ num_conv_outs = len(self.fpn_convs)
+ outs = []
+ for i in range(num_conv_outs):
+ out = self.fpn_convs[i](laterals[i])
+ outs.append(out)
+ return tuple(outs)
diff --git a/mmdet/models/necks/hrfpn.py b/mmdet/models/necks/hrfpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4f194832fc4b6ea77ce54262fb8ffa8675fc4e
--- /dev/null
+++ b/mmdet/models/necks/hrfpn.py
@@ -0,0 +1,102 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, caffe2_xavier_init
+from torch.utils.checkpoint import checkpoint
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class HRFPN(nn.Module):
+ """HRFPN (High Resolution Feature Pyramids)
+
+ paper: `High-Resolution Representations for Labeling Pixels and Regions
+    <https://arxiv.org/abs/1904.04514>`_.
+
+ Args:
+ in_channels (list): number of channels for each branch.
+ out_channels (int): output channels of feature pyramids.
+ num_outs (int): number of output stages.
+ pooling_type (str): pooling for generating feature pyramids
+ from {MAX, AVG}.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+ memory while slowing down the training speed.
+ stride (int): stride of 3x3 convolutional layers
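+
+    Example:
+        >>> # Toy HRNet-like inputs: four branches that halve in resolution
+        >>> # and double in channels (channel counts here are arbitrary).
+        >>> import torch
+        >>> in_channels = [18, 36, 72, 144]
+        >>> inputs = [torch.rand(1, c, 64 // 2**i, 64 // 2**i)
+        ...           for i, c in enumerate(in_channels)]
+        >>> self = HRFPN(in_channels, 32, num_outs=5).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 32, 64, 64])
+        torch.Size([1, 32, 32, 32])
+        torch.Size([1, 32, 16, 16])
+        torch.Size([1, 32, 8, 8])
+        torch.Size([1, 32, 4, 4])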
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs=5,
+ pooling_type='AVG',
+ conv_cfg=None,
+ norm_cfg=None,
+ with_cp=False,
+ stride=1):
+ super(HRFPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.with_cp = with_cp
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ self.reduction_conv = ConvModule(
+ sum(in_channels),
+ out_channels,
+ kernel_size=1,
+ conv_cfg=self.conv_cfg,
+ act_cfg=None)
+
+ self.fpn_convs = nn.ModuleList()
+ for i in range(self.num_outs):
+ self.fpn_convs.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ kernel_size=3,
+ padding=1,
+ stride=stride,
+ conv_cfg=self.conv_cfg,
+ act_cfg=None))
+
+ if pooling_type == 'MAX':
+ self.pooling = F.max_pool2d
+ else:
+ self.pooling = F.avg_pool2d
+
+ def init_weights(self):
+ """Initialize the weights of module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ caffe2_xavier_init(m)
+
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == self.num_ins
+ outs = [inputs[0]]
+ for i in range(1, self.num_ins):
+ outs.append(
+ F.interpolate(inputs[i], scale_factor=2**i, mode='bilinear'))
+ out = torch.cat(outs, dim=1)
+ if out.requires_grad and self.with_cp:
+ out = checkpoint(self.reduction_conv, out)
+ else:
+ out = self.reduction_conv(out)
+ outs = [out]
+ for i in range(1, self.num_outs):
+ outs.append(self.pooling(out, kernel_size=2**i, stride=2**i))
+ outputs = []
+
+ for i in range(self.num_outs):
+ if outs[i].requires_grad and self.with_cp:
+ tmp_out = checkpoint(self.fpn_convs[i], outs[i])
+ else:
+ tmp_out = self.fpn_convs[i](outs[i])
+ outputs.append(tmp_out)
+ return tuple(outputs)
diff --git a/mmdet/models/necks/nas_fpn.py b/mmdet/models/necks/nas_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e333ce65d4d06c47c29af489526ba3142736ad7
--- /dev/null
+++ b/mmdet/models/necks/nas_fpn.py
@@ -0,0 +1,160 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, caffe2_xavier_init
+from mmcv.ops.merge_cells import GlobalPoolingCell, SumCell
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class NASFPN(nn.Module):
+ """NAS-FPN.
+
+ Implementation of `NAS-FPN: Learning Scalable Feature Pyramid Architecture
+    for Object Detection <https://arxiv.org/abs/1904.07392>`_
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ stack_times (int): The number of times the pyramid architecture will
+ be stacked.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool): Whether to add conv layers on top of the
+            original feature maps. Default: False.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
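+
+    Example:
+        >>> # A toy-sized sketch; NAS-FPN always works on five pyramid
+        >>> # levels (P3-P7), so num_outs must be 5 here.
+        >>> import torch
+        >>> in_channels = [8, 16, 32]
+        >>> inputs = [torch.rand(1, c, 32 // 2**i, 32 // 2**i)
+        ...           for i, c in enumerate(in_channels)]
+        >>> self = NASFPN(in_channels, 16, num_outs=5, stack_times=1).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 16, 32, 32])
+        torch.Size([1, 16, 16, 16])
+        torch.Size([1, 16, 8, 8])
+        torch.Size([1, 16, 4, 4])
+        torch.Size([1, 16, 2, 2])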
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ stack_times,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ norm_cfg=None):
+ super(NASFPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels) # num of input feature levels
+ self.num_outs = num_outs # num of output feature levels
+ self.stack_times = stack_times
+ self.norm_cfg = norm_cfg
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ # if end_level < inputs, no extra level is allowed
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+
+ # add lateral connections
+ self.lateral_convs = nn.ModuleList()
+ for i in range(self.start_level, self.backbone_end_level):
+ l_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+ self.lateral_convs.append(l_conv)
+
+ # add extra downsample layers (stride-2 pooling or conv)
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+ self.extra_downsamples = nn.ModuleList()
+ for i in range(extra_levels):
+ extra_conv = ConvModule(
+ out_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
+ self.extra_downsamples.append(
+ nn.Sequential(extra_conv, nn.MaxPool2d(2, 2)))
+
+ # add NAS FPN connections
+ self.fpn_stages = nn.ModuleList()
+ for _ in range(self.stack_times):
+ stage = nn.ModuleDict()
+ # gp(p6, p4) -> p4_1
+ stage['gp_64_4'] = GlobalPoolingCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p4_1, p4) -> p4_2
+ stage['sum_44_4'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p4_2, p3) -> p3_out
+ stage['sum_43_3'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p3_out, p4_2) -> p4_out
+ stage['sum_34_4'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p5, gp(p4_out, p3_out)) -> p5_out
+ stage['gp_43_5'] = GlobalPoolingCell(with_out_conv=False)
+ stage['sum_55_5'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # sum(p7, gp(p5_out, p4_2)) -> p7_out
+ stage['gp_54_7'] = GlobalPoolingCell(with_out_conv=False)
+ stage['sum_77_7'] = SumCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ # gp(p7_out, p5_out) -> p6_out
+ stage['gp_75_6'] = GlobalPoolingCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ out_norm_cfg=norm_cfg)
+ self.fpn_stages.append(stage)
+
+ def init_weights(self):
+ """Initialize the weights of module."""
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ caffe2_xavier_init(m)
+
+ def forward(self, inputs):
+ """Forward function."""
+ # build P3-P5
+ feats = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+ # build P6-P7 on top of P5
+ for downsample in self.extra_downsamples:
+ feats.append(downsample(feats[-1]))
+
+ p3, p4, p5, p6, p7 = feats
+
+ for stage in self.fpn_stages:
+ # gp(p6, p4) -> p4_1
+ p4_1 = stage['gp_64_4'](p6, p4, out_size=p4.shape[-2:])
+ # sum(p4_1, p4) -> p4_2
+ p4_2 = stage['sum_44_4'](p4_1, p4, out_size=p4.shape[-2:])
+ # sum(p4_2, p3) -> p3_out
+ p3 = stage['sum_43_3'](p4_2, p3, out_size=p3.shape[-2:])
+ # sum(p3_out, p4_2) -> p4_out
+ p4 = stage['sum_34_4'](p3, p4_2, out_size=p4.shape[-2:])
+ # sum(p5, gp(p4_out, p3_out)) -> p5_out
+ p5_tmp = stage['gp_43_5'](p4, p3, out_size=p5.shape[-2:])
+ p5 = stage['sum_55_5'](p5, p5_tmp, out_size=p5.shape[-2:])
+ # sum(p7, gp(p5_out, p4_2)) -> p7_out
+ p7_tmp = stage['gp_54_7'](p5, p4_2, out_size=p7.shape[-2:])
+ p7 = stage['sum_77_7'](p7, p7_tmp, out_size=p7.shape[-2:])
+ # gp(p7_out, p5_out) -> p6_out
+ p6 = stage['gp_75_6'](p7, p5, out_size=p6.shape[-2:])
+
+ return p3, p4, p5, p6, p7
diff --git a/mmdet/models/necks/nasfcos_fpn.py b/mmdet/models/necks/nasfcos_fpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..2daf79ef591373499184c624ccd27fb7456dec06
--- /dev/null
+++ b/mmdet/models/necks/nasfcos_fpn.py
@@ -0,0 +1,161 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, caffe2_xavier_init
+from mmcv.ops.merge_cells import ConcatCell
+
+from ..builder import NECKS
+
+
+@NECKS.register_module()
+class NASFCOS_FPN(nn.Module):
+ """FPN structure in NASFPN.
+
+ Implementation of paper `NAS-FCOS: Fast Neural Architecture Search for
+ Object Detection `_
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+        add_extra_convs (bool): Whether to add conv layers on top of the
+            original feature maps. Default: False.
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ norm_cfg (dict): dictionary to construct and config norm layer.
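+
+    Example:
+        >>> # A toy-sized sketch; with the default start_level=1 the first
+        >>> # (C2-like) input is skipped, and exactly four input levels are
+        >>> # expected by the searched cell wiring.
+        >>> import torch
+        >>> in_channels = [8, 16, 32, 64]
+        >>> inputs = [torch.rand(1, c, 64 // 2**i, 64 // 2**i)
+        ...           for i, c in enumerate(in_channels)]
+        >>> self = NASFCOS_FPN(in_channels, 8, num_outs=5).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 8, 32, 32])
+        torch.Size([1, 8, 16, 16])
+        torch.Size([1, 8, 8, 8])
+        torch.Size([1, 8, 4, 4])
+        torch.Size([1, 8, 2, 2])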
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=1,
+ end_level=-1,
+ add_extra_convs=False,
+ conv_cfg=None,
+ norm_cfg=None):
+ super(NASFCOS_FPN, self).__init__()
+ assert isinstance(in_channels, list)
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.num_ins = len(in_channels)
+ self.num_outs = num_outs
+ self.norm_cfg = norm_cfg
+ self.conv_cfg = conv_cfg
+
+ if end_level == -1:
+ self.backbone_end_level = self.num_ins
+ assert num_outs >= self.num_ins - start_level
+ else:
+ self.backbone_end_level = end_level
+ assert end_level <= len(in_channels)
+ assert num_outs == end_level - start_level
+ self.start_level = start_level
+ self.end_level = end_level
+ self.add_extra_convs = add_extra_convs
+
+ self.adapt_convs = nn.ModuleList()
+ for i in range(self.start_level, self.backbone_end_level):
+ adapt_conv = ConvModule(
+ in_channels[i],
+ out_channels,
+ 1,
+ stride=1,
+ padding=0,
+ bias=False,
+ norm_cfg=dict(type='BN'),
+ act_cfg=dict(type='ReLU', inplace=False))
+ self.adapt_convs.append(adapt_conv)
+
+ # C2 is omitted according to the paper
+ extra_levels = num_outs - self.backbone_end_level + self.start_level
+
+ def build_concat_cell(with_input1_conv, with_input2_conv):
+ cell_conv_cfg = dict(
+ kernel_size=1, padding=0, bias=False, groups=out_channels)
+ return ConcatCell(
+ in_channels=out_channels,
+ out_channels=out_channels,
+ with_out_conv=True,
+ out_conv_cfg=cell_conv_cfg,
+ out_norm_cfg=dict(type='BN'),
+ out_conv_order=('norm', 'act', 'conv'),
+ with_input1_conv=with_input1_conv,
+ with_input2_conv=with_input2_conv,
+ input_conv_cfg=conv_cfg,
+ input_norm_cfg=norm_cfg,
+ upsample_mode='nearest')
+
+        # Denote c3=f0, c4=f1, c5=f2 for convenience
+ self.fpn = nn.ModuleDict()
+ self.fpn['c22_1'] = build_concat_cell(True, True)
+ self.fpn['c22_2'] = build_concat_cell(True, True)
+ self.fpn['c32'] = build_concat_cell(True, False)
+ self.fpn['c02'] = build_concat_cell(True, False)
+ self.fpn['c42'] = build_concat_cell(True, True)
+ self.fpn['c36'] = build_concat_cell(True, True)
+ self.fpn['c61'] = build_concat_cell(True, True) # f9
+ self.extra_downsamples = nn.ModuleList()
+ for i in range(extra_levels):
+ extra_act_cfg = None if i == 0 \
+ else dict(type='ReLU', inplace=False)
+ self.extra_downsamples.append(
+ ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ act_cfg=extra_act_cfg,
+ order=('act', 'norm', 'conv')))
+
+ def forward(self, inputs):
+ """Forward function."""
+ feats = [
+ adapt_conv(inputs[i + self.start_level])
+ for i, adapt_conv in enumerate(self.adapt_convs)
+ ]
+
+ for (i, module_name) in enumerate(self.fpn):
+ idx_1, idx_2 = int(module_name[1]), int(module_name[2])
+ res = self.fpn[module_name](feats[idx_1], feats[idx_2])
+ feats.append(res)
+
+ ret = []
+ for (idx, input_idx) in zip([9, 8, 7], [1, 2, 3]): # add P3, P4, P5
+ feats1, feats2 = feats[idx], feats[5]
+ feats2_resize = F.interpolate(
+ feats2,
+ size=feats1.size()[2:],
+ mode='bilinear',
+ align_corners=False)
+
+ feats_sum = feats1 + feats2_resize
+ ret.append(
+ F.interpolate(
+ feats_sum,
+ size=inputs[input_idx].size()[2:],
+ mode='bilinear',
+ align_corners=False))
+
+ for submodule in self.extra_downsamples:
+ ret.append(submodule(ret[-1]))
+
+ return tuple(ret)
+
+ def init_weights(self):
+ """Initialize the weights of module."""
+ for module in self.fpn.values():
+ if hasattr(module, 'conv_out'):
+ caffe2_xavier_init(module.out_conv.conv)
+
+ for modules in [
+ self.adapt_convs.modules(),
+ self.extra_downsamples.modules()
+ ]:
+ for module in modules:
+ if isinstance(module, nn.Conv2d):
+ caffe2_xavier_init(module)
diff --git a/mmdet/models/necks/pafpn.py b/mmdet/models/necks/pafpn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7c0b50f29e882aacb5158b33ead3d4566d0ce0b
--- /dev/null
+++ b/mmdet/models/necks/pafpn.py
@@ -0,0 +1,142 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+from mmcv.runner import auto_fp16
+
+from ..builder import NECKS
+from .fpn import FPN
+
+
+@NECKS.register_module()
+class PAFPN(FPN):
+ """Path Aggregation Network for Instance Segmentation.
+
+ This is an implementation of the `PAFPN in Path Aggregation Network
+    <https://arxiv.org/abs/1803.01534>`_.
+
+ Args:
+ in_channels (List[int]): Number of input channels per scale.
+ out_channels (int): Number of output channels (used at each scale)
+ num_outs (int): Number of output scales.
+ start_level (int): Index of the start input backbone level used to
+ build the feature pyramid. Default: 0.
+ end_level (int): Index of the end input backbone level (exclusive) to
+ build the feature pyramid. Default: -1, which means the last level.
+ add_extra_convs (bool): Whether to add conv layers on top of the
+ original feature maps. Default: False.
+ extra_convs_on_inputs (bool): Whether to apply extra conv on
+ the original feature from the backbone. Default: False.
+ relu_before_extra_convs (bool): Whether to apply relu before the extra
+ conv. Default: False.
+ no_norm_on_lateral (bool): Whether to apply norm on lateral.
+ Default: False.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Config dict for normalization layer. Default: None.
+ act_cfg (str): Config dict for activation layer in ConvModule.
+ Default: None.
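+
+    Example:
+        >>> # A toy-sized sketch; compared to FPN, an extra bottom-up path
+        >>> # is added, but the output shapes are the same.
+        >>> import torch
+        >>> in_channels = [8, 16, 32]
+        >>> inputs = [torch.rand(1, c, 32 // 2**i, 32 // 2**i)
+        ...           for i, c in enumerate(in_channels)]
+        >>> self = PAFPN(in_channels, 8, num_outs=3).eval()
+        >>> outputs = self.forward(inputs)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 8, 32, 32])
+        torch.Size([1, 8, 16, 16])
+        torch.Size([1, 8, 8, 8])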
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ num_outs,
+ start_level=0,
+ end_level=-1,
+ add_extra_convs=False,
+ extra_convs_on_inputs=True,
+ relu_before_extra_convs=False,
+ no_norm_on_lateral=False,
+ conv_cfg=None,
+ norm_cfg=None,
+ act_cfg=None):
+ super(PAFPN,
+ self).__init__(in_channels, out_channels, num_outs, start_level,
+ end_level, add_extra_convs, extra_convs_on_inputs,
+ relu_before_extra_convs, no_norm_on_lateral,
+ conv_cfg, norm_cfg, act_cfg)
+ # add extra bottom up pathway
+ self.downsample_convs = nn.ModuleList()
+ self.pafpn_convs = nn.ModuleList()
+ for i in range(self.start_level + 1, self.backbone_end_level):
+ d_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ stride=2,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ pafpn_conv = ConvModule(
+ out_channels,
+ out_channels,
+ 3,
+ padding=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ inplace=False)
+ self.downsample_convs.append(d_conv)
+ self.pafpn_convs.append(pafpn_conv)
+
+ @auto_fp16()
+ def forward(self, inputs):
+ """Forward function."""
+ assert len(inputs) == len(self.in_channels)
+
+ # build laterals
+ laterals = [
+ lateral_conv(inputs[i + self.start_level])
+ for i, lateral_conv in enumerate(self.lateral_convs)
+ ]
+
+ # build top-down path
+ used_backbone_levels = len(laterals)
+ for i in range(used_backbone_levels - 1, 0, -1):
+ prev_shape = laterals[i - 1].shape[2:]
+ laterals[i - 1] += F.interpolate(
+ laterals[i], size=prev_shape, mode='nearest')
+
+ # build outputs
+ # part 1: from original levels
+ inter_outs = [
+ self.fpn_convs[i](laterals[i]) for i in range(used_backbone_levels)
+ ]
+
+ # part 2: add bottom-up path
+ for i in range(0, used_backbone_levels - 1):
+ inter_outs[i + 1] += self.downsample_convs[i](inter_outs[i])
+
+ outs = []
+ outs.append(inter_outs[0])
+ outs.extend([
+ self.pafpn_convs[i - 1](inter_outs[i])
+ for i in range(1, used_backbone_levels)
+ ])
+
+ # part 3: add extra levels
+ if self.num_outs > len(outs):
+ # use max pool to get more levels on top of outputs
+ # (e.g., Faster R-CNN, Mask R-CNN)
+ if not self.add_extra_convs:
+ for i in range(self.num_outs - used_backbone_levels):
+ outs.append(F.max_pool2d(outs[-1], 1, stride=2))
+ # add conv layers on top of original feature maps (RetinaNet)
+ else:
+ if self.add_extra_convs == 'on_input':
+ orig = inputs[self.backbone_end_level - 1]
+ outs.append(self.fpn_convs[used_backbone_levels](orig))
+ elif self.add_extra_convs == 'on_lateral':
+ outs.append(self.fpn_convs[used_backbone_levels](
+ laterals[-1]))
+ elif self.add_extra_convs == 'on_output':
+ outs.append(self.fpn_convs[used_backbone_levels](outs[-1]))
+ else:
+ raise NotImplementedError
+ for i in range(used_backbone_levels + 1, self.num_outs):
+ if self.relu_before_extra_convs:
+ outs.append(self.fpn_convs[i](F.relu(outs[-1])))
+ else:
+ outs.append(self.fpn_convs[i](outs[-1]))
+ return tuple(outs)
diff --git a/mmdet/models/necks/rfp.py b/mmdet/models/necks/rfp.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a63e63bdef0094c26c17526d5ddde75bd309cea
--- /dev/null
+++ b/mmdet/models/necks/rfp.py
@@ -0,0 +1,128 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import constant_init, kaiming_init, xavier_init
+
+from ..builder import NECKS, build_backbone
+from .fpn import FPN
+
+
+class ASPP(nn.Module):
+ """ASPP (Atrous Spatial Pyramid Pooling)
+
+ This is an implementation of the ASPP module used in DetectoRS
+ (https://arxiv.org/pdf/2006.02334.pdf)
+
+ Args:
+ in_channels (int): Number of input channels.
+ out_channels (int): Number of channels produced by this module
+ dilations (tuple[int]): Dilations of the four branches.
+ Default: (1, 3, 6, 1)
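+
+    Example:
+        >>> # A small sketch: the four branches (the last applied to the
+        >>> # globally pooled input) are concatenated, so the output has
+        >>> # 4 * out_channels channels at the input resolution.
+        >>> import torch
+        >>> self = ASPP(8, 4)
+        >>> x = torch.rand(1, 8, 16, 16)
+        >>> print(self.forward(x).shape)
+        torch.Size([1, 16, 16, 16])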
+ """
+
+ def __init__(self, in_channels, out_channels, dilations=(1, 3, 6, 1)):
+ super().__init__()
+ assert dilations[-1] == 1
+ self.aspp = nn.ModuleList()
+ for dilation in dilations:
+ kernel_size = 3 if dilation > 1 else 1
+ padding = dilation if dilation > 1 else 0
+ conv = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size=kernel_size,
+ stride=1,
+ dilation=dilation,
+ padding=padding,
+ bias=True)
+ self.aspp.append(conv)
+ self.gap = nn.AdaptiveAvgPool2d(1)
+ self.init_weights()
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+
+ def forward(self, x):
+ avg_x = self.gap(x)
+ out = []
+ for aspp_idx in range(len(self.aspp)):
+ inp = avg_x if (aspp_idx == len(self.aspp) - 1) else x
+ out.append(F.relu_(self.aspp[aspp_idx](inp)))
+ out[-1] = out[-1].expand_as(out[-2])
+ out = torch.cat(out, dim=1)
+ return out
+
+
+@NECKS.register_module()
+class RFP(FPN):
+ """RFP (Recursive Feature Pyramid)
+
+ This is an implementation of RFP in `DetectoRS
+    <https://arxiv.org/abs/2006.02334>`_. Different from standard FPN, the
+ input of RFP should be multi level features along with origin input image
+ of backbone.
+
+ Args:
+ rfp_steps (int): Number of unrolled steps of RFP.
+ rfp_backbone (dict): Configuration of the backbone for RFP.
+ aspp_out_channels (int): Number of output channels of ASPP module.
+ aspp_dilations (tuple[int]): Dilation rates of four branches.
+ Default: (1, 3, 6, 1)
+ """
+
+ def __init__(self,
+ rfp_steps,
+ rfp_backbone,
+ aspp_out_channels,
+ aspp_dilations=(1, 3, 6, 1),
+ **kwargs):
+ super().__init__(**kwargs)
+ self.rfp_steps = rfp_steps
+ self.rfp_modules = nn.ModuleList()
+ for rfp_idx in range(1, rfp_steps):
+ rfp_module = build_backbone(rfp_backbone)
+ self.rfp_modules.append(rfp_module)
+ self.rfp_aspp = ASPP(self.out_channels, aspp_out_channels,
+ aspp_dilations)
+ self.rfp_weight = nn.Conv2d(
+ self.out_channels,
+ 1,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ bias=True)
+
+ def init_weights(self):
+ # Avoid using super().init_weights(), which may alter the default
+ # initialization of the modules in self.rfp_modules that have missing
+ # keys in the pretrained checkpoint.
+ for convs in [self.lateral_convs, self.fpn_convs]:
+ for m in convs.modules():
+ if isinstance(m, nn.Conv2d):
+ xavier_init(m, distribution='uniform')
+ for rfp_idx in range(self.rfp_steps - 1):
+ self.rfp_modules[rfp_idx].init_weights(
+ self.rfp_modules[rfp_idx].pretrained)
+ constant_init(self.rfp_weight, 0)
+
+ def forward(self, inputs):
+ inputs = list(inputs)
+ assert len(inputs) == len(self.in_channels) + 1 # +1 for input image
+ img = inputs.pop(0)
+ # FPN forward
+ x = super().forward(tuple(inputs))
+ for rfp_idx in range(self.rfp_steps - 1):
+ rfp_feats = [x[0]] + list(
+ self.rfp_aspp(x[i]) for i in range(1, len(x)))
+ x_idx = self.rfp_modules[rfp_idx].rfp_forward(img, rfp_feats)
+ # FPN forward
+ x_idx = super().forward(x_idx)
+ x_new = []
+ for ft_idx in range(len(x_idx)):
+ add_weight = torch.sigmoid(self.rfp_weight(x_idx[ft_idx]))
+ x_new.append(add_weight * x_idx[ft_idx] +
+ (1 - add_weight) * x[ft_idx])
+ x = x_new
+ return x
diff --git a/mmdet/models/necks/yolo_neck.py b/mmdet/models/necks/yolo_neck.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2f9b9ef3859796c284c16ad1a92fe41ecbed613
--- /dev/null
+++ b/mmdet/models/necks/yolo_neck.py
@@ -0,0 +1,136 @@
+# Copyright (c) 2019 Western Digital Corporation or its affiliates.
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule
+
+from ..builder import NECKS
+
+
+class DetectionBlock(nn.Module):
+ """Detection block in YOLO neck.
+
+    Let out_channels = n; the DetectionBlock contains:
+    six ConvLayers, one Conv2D layer and one YoloLayer.
+    The first six ConvLayers are formed the following way:
+    1x1xn, 3x3x2n, 1x1xn, 3x3x2n, 1x1xn, 3x3x2n.
+    The Conv2D layer is 1x1x255.
+    Some blocks have a branch after the fifth ConvLayer.
+    The number of input channels is arbitrary (in_channels).
+
+ Args:
+ in_channels (int): The number of input channels.
+ out_channels (int): The number of output channels.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', negative_slope=0.1).
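+
+    Example:
+        >>> # Toy channel counts; the spatial size is preserved.
+        >>> import torch
+        >>> block = DetectionBlock(32, 16).eval()
+        >>> print(block(torch.rand(1, 32, 8, 8)).shape)
+        torch.Size([1, 16, 8, 8])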
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
+ super(DetectionBlock, self).__init__()
+ double_out_channels = out_channels * 2
+
+ # shortcut
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+ self.conv1 = ConvModule(in_channels, out_channels, 1, **cfg)
+ self.conv2 = ConvModule(
+ out_channels, double_out_channels, 3, padding=1, **cfg)
+ self.conv3 = ConvModule(double_out_channels, out_channels, 1, **cfg)
+ self.conv4 = ConvModule(
+ out_channels, double_out_channels, 3, padding=1, **cfg)
+ self.conv5 = ConvModule(double_out_channels, out_channels, 1, **cfg)
+
+ def forward(self, x):
+ tmp = self.conv1(x)
+ tmp = self.conv2(tmp)
+ tmp = self.conv3(tmp)
+ tmp = self.conv4(tmp)
+ out = self.conv5(tmp)
+ return out
+
+
+@NECKS.register_module()
+class YOLOV3Neck(nn.Module):
+ """The neck of YOLOV3.
+
+    It can be treated as a simplified version of FPN. It takes the
+    features from the Darknet backbone, performs some upsampling and
+    concatenation, and finally outputs the feature maps used for detection.
+
+    Note:
+        The input feats should be ordered from low level to high level
+        (i.e., from high resolution to low resolution), as returned by the
+        Darknet backbone. YOLOV3Neck processes them in reversed order,
+        starting from the last (highest-level) feature map.
+
+ Args:
+ num_scales (int): The number of scales / stages.
+        in_channels (List[int]): The number of input channels per scale.
+        out_channels (List[int]): The number of output channels per scale.
+ conv_cfg (dict): Config dict for convolution layer. Default: None.
+ norm_cfg (dict): Dictionary to construct and config norm layer.
+ Default: dict(type='BN', requires_grad=True)
+ act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='LeakyReLU', negative_slope=0.1).
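+
+    Example:
+        >>> # A toy-sized sketch; feats go from high resolution (low level)
+        >>> # to low resolution (high level), as the Darknet backbone returns
+        >>> # them, so out_channels[i - 1] must equal in_channels[i].
+        >>> import torch
+        >>> feats = [torch.rand(1, c, s, s)
+        ...          for c, s in zip([16, 32, 64], [32, 16, 8])]
+        >>> self = YOLOV3Neck(num_scales=3, in_channels=[64, 32, 16],
+        ...                   out_channels=[32, 16, 8]).eval()
+        >>> outputs = self.forward(feats)
+        >>> for out in outputs:
+        ...     print(out.shape)
+        torch.Size([1, 32, 8, 8])
+        torch.Size([1, 16, 16, 16])
+        torch.Size([1, 8, 32, 32])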
+ """
+
+ def __init__(self,
+ num_scales,
+ in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN', requires_grad=True),
+ act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
+ super(YOLOV3Neck, self).__init__()
+ assert (num_scales == len(in_channels) == len(out_channels))
+ self.num_scales = num_scales
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+
+ # shortcut
+ cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)
+
+ # To support arbitrary scales, the code looks awful, but it works.
+ # Better solution is welcomed.
+ self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
+ for i in range(1, self.num_scales):
+ in_c, out_c = self.in_channels[i], self.out_channels[i]
+ self.add_module(f'conv{i}', ConvModule(in_c, out_c, 1, **cfg))
+ # in_c + out_c : High-lvl feats will be cat with low-lvl feats
+ self.add_module(f'detect{i+1}',
+ DetectionBlock(in_c + out_c, out_c, **cfg))
+
+ def forward(self, feats):
+ assert len(feats) == self.num_scales
+
+ # processed from bottom (high-lvl) to top (low-lvl)
+ outs = []
+ out = self.detect1(feats[-1])
+ outs.append(out)
+
+ for i, x in enumerate(reversed(feats[:-1])):
+ conv = getattr(self, f'conv{i+1}')
+ tmp = conv(out)
+
+ # Cat with low-lvl feats
+ tmp = F.interpolate(tmp, scale_factor=2)
+ tmp = torch.cat((tmp, x), 1)
+
+ detect = getattr(self, f'detect{i+2}')
+ out = detect(tmp)
+ outs.append(out)
+
+ return tuple(outs)
+
+ def init_weights(self):
+ """Initialize the weights of module."""
+ # init is done in ConvModule
+ pass
diff --git a/mmdet/models/roi_heads/__init__.py b/mmdet/models/roi_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a3e515e40ffa26f83381342952ea9a0e1ccc235
--- /dev/null
+++ b/mmdet/models/roi_heads/__init__.py
@@ -0,0 +1,43 @@
+'''
+from .base_roi_head import BaseRoIHead
+from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead,
+ SCNetBBoxHead, Shared2FCBBoxHead,
+ Shared4Conv1FCBBoxHead)
+from .cascade_roi_head import CascadeRoIHead
+from .double_roi_head import DoubleHeadRoIHead
+from .dynamic_roi_head import DynamicRoIHead
+from .grid_roi_head import GridRoIHead
+from .htc_roi_head import HybridTaskCascadeRoIHead
+from .mask_heads import (CoarseMaskHead, FCNMaskHead, FeatureRelayHead,
+ FusedSemanticHead, GlobalContextHead, GridHead,
+ HTCMaskHead, MaskIoUHead, MaskPointHead,
+ SCNetMaskHead, SCNetSemanticHead)
+from .mask_scoring_roi_head import MaskScoringRoIHead
+from .pisa_roi_head import PISARoIHead
+from .point_rend_roi_head import PointRendRoIHead
+from .roi_extractors import SingleRoIExtractor
+from .scnet_roi_head import SCNetRoIHead
+from .shared_heads import ResLayer
+from .sparse_roi_head import SparseRoIHead
+from .standard_roi_head import StandardRoIHead
+from .trident_roi_head import TridentRoIHead
+
+__all__ = [
+ 'BaseRoIHead', 'CascadeRoIHead', 'DoubleHeadRoIHead', 'MaskScoringRoIHead',
+ 'HybridTaskCascadeRoIHead', 'GridRoIHead', 'ResLayer', 'BBoxHead',
+ 'ConvFCBBoxHead', 'Shared2FCBBoxHead', 'StandardRoIHead',
+ 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'FCNMaskHead',
+ 'HTCMaskHead', 'FusedSemanticHead', 'GridHead', 'MaskIoUHead',
+ 'SingleRoIExtractor', 'PISARoIHead', 'PointRendRoIHead', 'MaskPointHead',
+ 'CoarseMaskHead', 'DynamicRoIHead', 'SparseRoIHead', 'TridentRoIHead',
+ 'SCNetRoIHead', 'SCNetMaskHead', 'SCNetSemanticHead', 'SCNetBBoxHead',
+ 'FeatureRelayHead', 'GlobalContextHead'
+]
+'''
+from .bbox_heads import (BBoxHead, ConvFCBBoxHead, DoubleConvFCBBoxHead,
+ SCNetBBoxHead, Shared2FCBBoxHead,
+ Shared4Conv1FCBBoxHead)
+from .standard_roi_head import StandardRoIHead
+from .roi_extractors import SingleRoIExtractor
+from .mask_heads import FCNMaskHead
+__all__ = [
+    'BBoxHead', 'StandardRoIHead', 'SingleRoIExtractor', 'Shared2FCBBoxHead',
+    'FCNMaskHead'
+]
diff --git a/mmdet/models/roi_heads/base_roi_head.py b/mmdet/models/roi_heads/base_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b560a913a52f7e2d6ee5a8f85589fa07e92118b
--- /dev/null
+++ b/mmdet/models/roi_heads/base_roi_head.py
@@ -0,0 +1,114 @@
+from abc import ABCMeta, abstractmethod
+
+import torch.nn as nn
+
+from ..builder import build_shared_head
+
+
+class BaseRoIHead(nn.Module, metaclass=ABCMeta):
+ """Base class for RoIHeads."""
+
+ def __init__(self,
+ bbox_roi_extractor=None,
+ bbox_head=None,
+ mask_roi_extractor=None,
+ mask_head=None,
+ gan_roi_extractor=None,
+ gan_head=None,
+ shared_head=None,
+ train_cfg=None,
+ test_cfg=None):
+ super(BaseRoIHead, self).__init__()
+ self.train_cfg = train_cfg
+ self.test_cfg = test_cfg
+ if shared_head is not None:
+ self.shared_head = build_shared_head(shared_head)
+
+ if bbox_head is not None:
+ self.init_bbox_head(bbox_roi_extractor, bbox_head)
+
+ if mask_head is not None:
+ self.init_mask_head(mask_roi_extractor, mask_head)
+
+        if gan_head is not None:
+            self.init_gan_head(gan_roi_extractor, gan_head)
+
+ self.init_assigner_sampler()
+
+ @property
+ def with_bbox(self):
+ """bool: whether the RoI head contains a `bbox_head`"""
+ return hasattr(self, 'bbox_head') and self.bbox_head is not None
+
+ @property
+ def with_mask(self):
+ """bool: whether the RoI head contains a `mask_head`"""
+ return hasattr(self, 'mask_head') and self.mask_head is not None
+
+ @property
+ def with_shared_head(self):
+ """bool: whether the RoI head contains a `shared_head`"""
+ return hasattr(self, 'shared_head') and self.shared_head is not None
+
+ @abstractmethod
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ pass
+
+ @abstractmethod
+ def init_bbox_head(self):
+ """Initialize ``bbox_head``"""
+ pass
+
+ @abstractmethod
+ def init_mask_head(self):
+ """Initialize ``mask_head``"""
+ pass
+
+ @abstractmethod
+ def init_gan_head(self):
+ """Initialize ``gan_head``"""
+ pass
+
+ @abstractmethod
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler."""
+ pass
+
+ @abstractmethod
+ def forward_train(self,
+ x,
+ img_meta,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ **kwargs):
+ """Forward function during training."""
+
+ async def async_simple_test(self, x, img_meta, **kwargs):
+ """Asynchronized test function."""
+ raise NotImplementedError
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_meta,
+ proposals=None,
+ rescale=False,
+ **kwargs):
+ """Test without augmentation."""
+
+ def aug_test(self, x, proposal_list, img_metas, rescale=False, **kwargs):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
diff --git a/mmdet/models/roi_heads/bbox_heads/__init__.py b/mmdet/models/roi_heads/bbox_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc5d29ece5bbf2f168f538f151f06d1b263a5153
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/__init__.py
@@ -0,0 +1,13 @@
+from .bbox_head import BBoxHead
+from .convfc_bbox_head import (ConvFCBBoxHead, Shared2FCBBoxHead,
+ Shared4Conv1FCBBoxHead)
+from .dii_head import DIIHead
+from .double_bbox_head import DoubleConvFCBBoxHead
+from .sabl_head import SABLHead
+from .scnet_bbox_head import SCNetBBoxHead
+
+__all__ = [
+ 'BBoxHead', 'ConvFCBBoxHead', 'Shared2FCBBoxHead',
+ 'Shared4Conv1FCBBoxHead', 'DoubleConvFCBBoxHead', 'SABLHead', 'DIIHead',
+ 'SCNetBBoxHead'
+]
diff --git a/mmdet/models/roi_heads/bbox_heads/bbox_head.py b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..408abef3a244115b4e73748049a228e37ad0665c
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/bbox_head.py
@@ -0,0 +1,483 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.runner import auto_fp16, force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.losses import accuracy
+
+
+@HEADS.register_module()
+class BBoxHead(nn.Module):
+ """Simplest RoI head, with only two fc layers for classification and
+ regression respectively."""
+
+ def __init__(self,
+ with_avg_pool=False,
+ with_cls=True,
+ with_reg=True,
+ roi_feat_size=7,
+ in_channels=256,
+ num_classes=80,
+ bbox_coder=dict(
+ type='DeltaXYWHBBoxCoder',
+ clip_border=True,
+ target_means=[0., 0., 0., 0.],
+ target_stds=[0.1, 0.1, 0.2, 0.2]),
+ reg_class_agnostic=False,
+ reg_decoded_bbox=False,
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox=dict(
+ type='SmoothL1Loss', beta=1.0, loss_weight=1.0)):
+ super(BBoxHead, self).__init__()
+ assert with_cls or with_reg
+ self.with_avg_pool = with_avg_pool
+ self.with_cls = with_cls
+ self.with_reg = with_reg
+ self.roi_feat_size = _pair(roi_feat_size)
+ self.roi_feat_area = self.roi_feat_size[0] * self.roi_feat_size[1]
+ self.in_channels = in_channels
+ self.num_classes = num_classes
+ self.reg_class_agnostic = reg_class_agnostic
+ self.reg_decoded_bbox = reg_decoded_bbox
+ self.fp16_enabled = False
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox = build_loss(loss_bbox)
+
+ in_channels = self.in_channels
+ if self.with_avg_pool:
+ self.avg_pool = nn.AvgPool2d(self.roi_feat_size)
+ else:
+ in_channels *= self.roi_feat_area
+ if self.with_cls:
+ # need to add background class
+ self.fc_cls = nn.Linear(in_channels, num_classes + 1)
+ if self.with_reg:
+ out_dim_reg = 4 if reg_class_agnostic else 4 * num_classes
+ self.fc_reg = nn.Linear(in_channels, out_dim_reg)
+ self.debug_imgs = None
+
+ def init_weights(self):
+ # conv layers are already initialized by ConvModule
+ if self.with_cls:
+ nn.init.normal_(self.fc_cls.weight, 0, 0.01)
+ nn.init.constant_(self.fc_cls.bias, 0)
+ if self.with_reg:
+ nn.init.normal_(self.fc_reg.weight, 0, 0.001)
+ nn.init.constant_(self.fc_reg.bias, 0)
+
+ @auto_fp16()
+ def forward(self, x):
+ if self.with_avg_pool:
+ x = self.avg_pool(x)
+ x = x.view(x.size(0), -1)
+ cls_score = self.fc_cls(x) if self.with_cls else None
+ bbox_pred = self.fc_reg(x) if self.with_reg else None
+ return cls_score, bbox_pred
+
+ def _get_target_single(self, pos_bboxes, neg_bboxes, pos_gt_bboxes,
+ pos_gt_labels, cfg):
+        """Calculate the ground truth for proposals in a single image
+ according to the sampling results.
+
+ Args:
+ pos_bboxes (Tensor): Contains all the positive boxes,
+ has shape (num_pos, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ neg_bboxes (Tensor): Contains all the negative boxes,
+ has shape (num_neg, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_bboxes (Tensor): Contains all the gt_boxes,
+ has shape (num_gt, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_labels (Tensor): Contains all the gt_labels,
+ has shape (num_gt).
+ cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.
+
+ Returns:
+ Tuple[Tensor]: Ground truth for proposals
+ in a single image. Containing the following Tensors:
+
+ - labels(Tensor): Gt_labels for all proposals, has
+ shape (num_proposals,).
+ - label_weights(Tensor): Labels_weights for all
+ proposals, has shape (num_proposals,).
+ - bbox_targets(Tensor):Regression target for all
+ proposals, has shape (num_proposals, 4), the
+ last dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ - bbox_weights(Tensor):Regression weights for all
+ proposals, has shape (num_proposals, 4).
+ """
+ num_pos = pos_bboxes.size(0)
+ num_neg = neg_bboxes.size(0)
+ num_samples = num_pos + num_neg
+
+ # original implementation uses new_zeros since BG are set to be 0
+ # now use empty & fill because BG cat_id = num_classes,
+ # FG cat_id = [0, num_classes-1]
+ labels = pos_bboxes.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_bboxes.new_zeros(num_samples)
+ bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
+ bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
+ if num_pos > 0:
+ labels[:num_pos] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[:num_pos] = pos_weight
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ pos_bboxes, pos_gt_bboxes)
+ else:
+ # When the regression loss (e.g. `IouLoss`, `GIouLoss`)
+ # is applied directly on the decoded bounding boxes, both
+ # the predicted boxes and regression targets should be with
+ # absolute coordinate format.
+ pos_bbox_targets = pos_gt_bboxes
+ bbox_targets[:num_pos, :] = pos_bbox_targets
+ bbox_weights[:num_pos, :] = 1
+ if num_neg > 0:
+ label_weights[-num_neg:] = 1.0
+
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ def get_targets(self,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ concat=True):
+ """Calculate the ground truth for all samples in a batch according to
+ the sampling_results.
+
+        Targets are computed for each image by `_get_target_single` and are
+        concatenated across the batch when `concat=True`.
+
+ Args:
+ sampling_results (List[obj:SamplingResults]): Assign results of
+ all images in a batch after sampling.
+ gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
+ each tensor has shape (num_gt, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ gt_labels (list[Tensor]): Gt_labels of all images in a batch,
+ each tensor has shape (num_gt,).
+ rcnn_train_cfg (obj:ConfigDict): `train_cfg` of RCNN.
+ concat (bool): Whether to concatenate the results of all
+ the images in a single batch.
+
+ Returns:
+            Tuple[Tensor]: Ground truth for all proposals in a batch.
+                Containing the following list of Tensors:
+
+ - labels (list[Tensor],Tensor): Gt_labels for all
+ proposals in a batch, each tensor in list has
+ shape (num_proposals,) when `concat=False`, otherwise
+ just a single tensor has shape (num_all_proposals,).
+ - label_weights (list[Tensor]): Labels_weights for
+ all proposals in a batch, each tensor in list has
+ shape (num_proposals,) when `concat=False`, otherwise
+ just a single tensor has shape (num_all_proposals,).
+ - bbox_targets (list[Tensor],Tensor): Regression target
+ for all proposals in a batch, each tensor in list
+ has shape (num_proposals, 4) when `concat=False`,
+ otherwise just a single tensor has shape
+ (num_all_proposals, 4), the last dimension 4 represents
+ [tl_x, tl_y, br_x, br_y].
+ - bbox_weights (list[tensor],Tensor): Regression weights for
+ all proposals in a batch, each tensor in list has shape
+ (num_proposals, 4) when `concat=False`, otherwise just a
+ single tensor has shape (num_all_proposals, 4).
+ """
+ pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
+ neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
+ pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ labels, label_weights, bbox_targets, bbox_weights = multi_apply(
+ self._get_target_single,
+ pos_bboxes_list,
+ neg_bboxes_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ cfg=rcnn_train_cfg)
+
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ bbox_targets = torch.cat(bbox_targets, 0)
+ bbox_weights = torch.cat(bbox_weights, 0)
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def loss(self,
+ cls_score,
+ bbox_pred,
+ rois,
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ reduction_override=None):
+ losses = dict()
+ if cls_score is not None:
+ avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
+ if cls_score.numel() > 0:
+ losses['loss_cls'] = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['acc'] = accuracy(cls_score, labels)
+ if bbox_pred is not None:
+ bg_class_ind = self.num_classes
+ # 0~self.num_classes-1 are FG, self.num_classes is BG
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ # do not perform bounding box regression for BG anymore.
+ if pos_inds.any():
+ if self.reg_decoded_bbox:
+ # When the regression loss (e.g. `IouLoss`,
+ # `GIouLoss`, `DIouLoss`) is applied directly on
+ # the decoded bounding boxes, it decodes the
+ # already encoded coordinates to absolute format.
+ bbox_pred = self.bbox_coder.decode(rois[:, 1:], bbox_pred)
+ if self.reg_class_agnostic:
+ pos_bbox_pred = bbox_pred.view(
+ bbox_pred.size(0), 4)[pos_inds.type(torch.bool)]
+ else:
+ pos_bbox_pred = bbox_pred.view(
+ bbox_pred.size(0), -1,
+ 4)[pos_inds.type(torch.bool),
+ labels[pos_inds.type(torch.bool)]]
+ losses['loss_bbox'] = self.loss_bbox(
+ pos_bbox_pred,
+ bbox_targets[pos_inds.type(torch.bool)],
+ bbox_weights[pos_inds.type(torch.bool)],
+ avg_factor=bbox_targets.size(0),
+ reduction_override=reduction_override)
+ else:
+ losses['loss_bbox'] = bbox_pred[pos_inds].sum()
+ return losses
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def get_bboxes(self,
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None):
+ """Transform network output for a batch into bbox predictions.
+
+        If the input rois has a batch dimension, the function works in
+        `batch_mode` and returns a tuple[list[Tensor], list[Tensor]];
+        otherwise, it returns a tuple[Tensor, Tensor].
+
+ Args:
+ rois (Tensor): Boxes to be transformed. Has shape (num_boxes, 5)
+ or (B, num_boxes, 5)
+ cls_score (list[Tensor] or Tensor): Box scores for
+ each scale level, each is a 4D-tensor, the channel number is
+ num_points * num_classes.
+ bbox_pred (Tensor, optional): Box energies / deltas for each scale
+ level, each is a 4D-tensor, the channel number is
+ num_classes * 4.
+ img_shape (Sequence[int] or torch.Tensor or Sequence[
+ Sequence[int]], optional): Maximum bounds for boxes, specifies
+ (H, W, C) or (H, W). If rois shape is (B, num_boxes, 4), then
+ the max_shape should be a Sequence[Sequence[int]]
+ and the length of max_shape should also be B.
+ scale_factor (tuple[ndarray] or ndarray): Scale factor of the
+                image arranged as (w_scale, h_scale, w_scale, h_scale). In
+ `batch_mode`, the scale_factor shape is tuple[ndarray].
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+ cfg (obj:`ConfigDict`): `test_cfg` of Bbox Head. Default: None
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]] or tuple[Tensor, Tensor]:
+ If the input has a batch dimension, the return value is
+ a tuple of the list. The first list contains the boxes of
+ the corresponding image in a batch, each tensor has the
+ shape (num_boxes, 5) and last dimension 5 represent
+ (tl_x, tl_y, br_x, br_y, score). Each Tensor in the second
+ list is the labels with shape (num_boxes, ). The length of
+ both lists should be equal to batch_size. Otherwise return
+ value is a tuple of two tensors, the first tensor is the
+ boxes with scores, the second tensor is the labels, both
+ have the same shape as the first case.
+ """
+ if isinstance(cls_score, list):
+ cls_score = sum(cls_score) / float(len(cls_score))
+
+ scores = F.softmax(
+ cls_score, dim=-1) if cls_score is not None else None
+
+ batch_mode = True
+ if rois.ndim == 2:
+ # e.g. AugTest, Cascade R-CNN, HTC, SCNet...
+ batch_mode = False
+
+ # add batch dimension
+ if scores is not None:
+ scores = scores.unsqueeze(0)
+ if bbox_pred is not None:
+ bbox_pred = bbox_pred.unsqueeze(0)
+ rois = rois.unsqueeze(0)
+
+ if bbox_pred is not None:
+ bboxes = self.bbox_coder.decode(
+ rois[..., 1:], bbox_pred, max_shape=img_shape)
+ else:
+ bboxes = rois[..., 1:].clone()
+ if img_shape is not None:
+ max_shape = bboxes.new_tensor(img_shape)[..., :2]
+ min_xy = bboxes.new_tensor(0)
+ max_xy = torch.cat(
+ [max_shape] * 2, dim=-1).flip(-1).unsqueeze(-2)
+ bboxes = torch.where(bboxes < min_xy, min_xy, bboxes)
+ bboxes = torch.where(bboxes > max_xy, max_xy, bboxes)
+
+ if rescale and bboxes.size(-2) > 0:
+ if not isinstance(scale_factor, tuple):
+ scale_factor = tuple([scale_factor])
+ # B, 1, bboxes.size(-1)
+ scale_factor = bboxes.new_tensor(scale_factor).unsqueeze(1).repeat(
+ 1, 1,
+ bboxes.size(-1) // 4)
+ bboxes /= scale_factor
+
+ det_bboxes = []
+ det_labels = []
+ for (bbox, score) in zip(bboxes, scores):
+ if cfg is not None:
+ det_bbox, det_label = multiclass_nms(bbox, score,
+ cfg.score_thr, cfg.nms,
+ cfg.max_per_img)
+ else:
+ det_bbox, det_label = bbox, score
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+
+ if not batch_mode:
+ det_bboxes = det_bboxes[0]
+ det_labels = det_labels[0]
+ return det_bboxes, det_labels
+
+ @force_fp32(apply_to=('bbox_preds', ))
+ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
+ """Refine bboxes during training.
+
+ Args:
+ rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
+ and bs is the sampled RoIs per image. The first column is
+ the image id and the next 4 columns are x1, y1, x2, y2.
+ labels (Tensor): Shape (n*bs, ).
+ bbox_preds (Tensor): Shape (n*bs, 4) or (n*bs, 4*#class).
+ pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
+ is a gt bbox.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Refined bboxes of each image in a mini-batch.
+
+ Example:
+ >>> # xdoctest: +REQUIRES(module:kwarray)
+ >>> import kwarray
+ >>> import numpy as np
+ >>> from mmdet.core.bbox.demodata import random_boxes
+ >>> self = BBoxHead(reg_class_agnostic=True)
+ >>> n_roi = 2
+ >>> n_img = 4
+ >>> scale = 512
+ >>> rng = np.random.RandomState(0)
+ >>> img_metas = [{'img_shape': (scale, scale)}
+ ... for _ in range(n_img)]
+ >>> # Create rois in the expected format
+ >>> roi_boxes = random_boxes(n_roi, scale=scale, rng=rng)
+ >>> img_ids = torch.randint(0, n_img, (n_roi,))
+ >>> img_ids = img_ids.float()
+ >>> rois = torch.cat([img_ids[:, None], roi_boxes], dim=1)
+ >>> # Create other args
+ >>> labels = torch.randint(0, 2, (n_roi,)).long()
+ >>> bbox_preds = random_boxes(n_roi, scale=scale, rng=rng)
+ >>> # For each image, pretend random positive boxes are gts
+            >>> is_label_pos = (labels.numpy() > 0).astype(int)
+ >>> lbl_per_img = kwarray.group_items(is_label_pos,
+ ... img_ids.numpy())
+ >>> pos_per_img = [sum(lbl_per_img.get(gid, []))
+ ... for gid in range(n_img)]
+ >>> pos_is_gts = [
+ >>> torch.randint(0, 2, (npos,)).byte().sort(
+ >>> descending=True)[0]
+ >>> for npos in pos_per_img
+ >>> ]
+ >>> bboxes_list = self.refine_bboxes(rois, labels, bbox_preds,
+ >>> pos_is_gts, img_metas)
+ >>> print(bboxes_list)
+ """
+ img_ids = rois[:, 0].long().unique(sorted=True)
+ assert img_ids.numel() <= len(img_metas)
+
+ bboxes_list = []
+ for i in range(len(img_metas)):
+ inds = torch.nonzero(
+ rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
+ num_rois = inds.numel()
+
+ bboxes_ = rois[inds, 1:]
+ label_ = labels[inds]
+ bbox_pred_ = bbox_preds[inds]
+ img_meta_ = img_metas[i]
+ pos_is_gts_ = pos_is_gts[i]
+
+ bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
+ img_meta_)
+
+ # filter gt bboxes
+ pos_keep = 1 - pos_is_gts_
+ keep_inds = pos_is_gts_.new_ones(num_rois)
+ keep_inds[:len(pos_is_gts_)] = pos_keep
+
+ bboxes_list.append(bboxes[keep_inds.type(torch.bool)])
+
+ return bboxes_list
+
+ @force_fp32(apply_to=('bbox_pred', ))
+ def regress_by_class(self, rois, label, bbox_pred, img_meta):
+ """Regress the bbox for the predicted class. Used in Cascade R-CNN.
+
+ Args:
+ rois (Tensor): shape (n, 4) or (n, 5)
+ label (Tensor): shape (n, )
+ bbox_pred (Tensor): shape (n, 4*(#class)) or (n, 4)
+ img_meta (dict): Image meta info.
+
+ Returns:
+ Tensor: Regressed bboxes, the same shape as input rois.
+ """
+ assert rois.size(1) == 4 or rois.size(1) == 5, repr(rois.shape)
+
+ if not self.reg_class_agnostic:
+ label = label * 4
+ inds = torch.stack((label, label + 1, label + 2, label + 3), 1)
+ bbox_pred = torch.gather(bbox_pred, 1, inds)
+ assert bbox_pred.size(1) == 4
+
+ if rois.size(1) == 4:
+ new_rois = self.bbox_coder.decode(
+ rois, bbox_pred, max_shape=img_meta['img_shape'])
+ else:
+ bboxes = self.bbox_coder.decode(
+ rois[:, 1:], bbox_pred, max_shape=img_meta['img_shape'])
+ new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
+
+ return new_rois
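
A quick shape check for the head above, illustrative only and assuming the bundled mmdet package and torch are importable: with the default 7x7 RoI features and 80 classes, forward returns (N, 81) classification scores (80 classes plus background) and (N, 320) class-specific box deltas.

import torch
from mmdet.models.roi_heads.bbox_heads import BBoxHead

head = BBoxHead(in_channels=256, roi_feat_size=7, num_classes=80)
head.init_weights()
roi_feats = torch.rand(8, 256, 7, 7)      # pooled features for 8 RoIs
cls_score, bbox_pred = head(roi_feats)
print(cls_score.shape)                    # torch.Size([8, 81])
print(bbox_pred.shape)                    # torch.Size([8, 320])
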
diff --git a/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e86d2ea67e154fae18dbf9d2bfde6d0a70e582c
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/convfc_bbox_head.py
@@ -0,0 +1,205 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+
+from mmdet.models.builder import HEADS
+from .bbox_head import BBoxHead
+
+
+@HEADS.register_module()
+class ConvFCBBoxHead(BBoxHead):
+ r"""More general bbox head, with shared conv and fc layers and two optional
+ separated branches.
+
+ .. code-block:: none
+
+ /-> cls convs -> cls fcs -> cls
+ shared convs -> shared fcs
+ \-> reg convs -> reg fcs -> reg
+ """ # noqa: W605
+
+ def __init__(self,
+ num_shared_convs=0,
+ num_shared_fcs=0,
+ num_cls_convs=0,
+ num_cls_fcs=0,
+ num_reg_convs=0,
+ num_reg_fcs=0,
+ conv_out_channels=256,
+ fc_out_channels=1024,
+ conv_cfg=None,
+ norm_cfg=None,
+ *args,
+ **kwargs):
+ super(ConvFCBBoxHead, self).__init__(*args, **kwargs)
+ assert (num_shared_convs + num_shared_fcs + num_cls_convs +
+ num_cls_fcs + num_reg_convs + num_reg_fcs > 0)
+ if num_cls_convs > 0 or num_reg_convs > 0:
+ assert num_shared_fcs == 0
+ if not self.with_cls:
+ assert num_cls_convs == 0 and num_cls_fcs == 0
+ if not self.with_reg:
+ assert num_reg_convs == 0 and num_reg_fcs == 0
+ self.num_shared_convs = num_shared_convs
+ self.num_shared_fcs = num_shared_fcs
+ self.num_cls_convs = num_cls_convs
+ self.num_cls_fcs = num_cls_fcs
+ self.num_reg_convs = num_reg_convs
+ self.num_reg_fcs = num_reg_fcs
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ # add shared convs and fcs
+ self.shared_convs, self.shared_fcs, last_layer_dim = \
+ self._add_conv_fc_branch(
+ self.num_shared_convs, self.num_shared_fcs, self.in_channels,
+ True)
+ self.shared_out_channels = last_layer_dim
+
+ # add cls specific branch
+ self.cls_convs, self.cls_fcs, self.cls_last_dim = \
+ self._add_conv_fc_branch(
+ self.num_cls_convs, self.num_cls_fcs, self.shared_out_channels)
+
+ # add reg specific branch
+ self.reg_convs, self.reg_fcs, self.reg_last_dim = \
+ self._add_conv_fc_branch(
+ self.num_reg_convs, self.num_reg_fcs, self.shared_out_channels)
+
+ if self.num_shared_fcs == 0 and not self.with_avg_pool:
+ if self.num_cls_fcs == 0:
+ self.cls_last_dim *= self.roi_feat_area
+ if self.num_reg_fcs == 0:
+ self.reg_last_dim *= self.roi_feat_area
+
+ self.relu = nn.ReLU(inplace=True)
+ # reconstruct fc_cls and fc_reg since input channels are changed
+ if self.with_cls:
+ self.fc_cls = nn.Linear(self.cls_last_dim, self.num_classes + 1)
+ if self.with_reg:
+ out_dim_reg = (4 if self.reg_class_agnostic else 4 *
+ self.num_classes)
+ self.fc_reg = nn.Linear(self.reg_last_dim, out_dim_reg)
+
+ def _add_conv_fc_branch(self,
+ num_branch_convs,
+ num_branch_fcs,
+ in_channels,
+ is_shared=False):
+        """Add shared or separated branch.
+
+ convs -> avg pool (optional) -> fcs
+ """
+ last_layer_dim = in_channels
+ # add branch specific conv layers
+ branch_convs = nn.ModuleList()
+ if num_branch_convs > 0:
+ for i in range(num_branch_convs):
+ conv_in_channels = (
+ last_layer_dim if i == 0 else self.conv_out_channels)
+ branch_convs.append(
+ ConvModule(
+ conv_in_channels,
+ self.conv_out_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ last_layer_dim = self.conv_out_channels
+ # add branch specific fc layers
+ branch_fcs = nn.ModuleList()
+ if num_branch_fcs > 0:
+ # for shared branch, only consider self.with_avg_pool
+ # for separated branches, also consider self.num_shared_fcs
+ if (is_shared
+ or self.num_shared_fcs == 0) and not self.with_avg_pool:
+ last_layer_dim *= self.roi_feat_area
+ for i in range(num_branch_fcs):
+ fc_in_channels = (
+ last_layer_dim if i == 0 else self.fc_out_channels)
+ branch_fcs.append(
+ nn.Linear(fc_in_channels, self.fc_out_channels))
+ last_layer_dim = self.fc_out_channels
+ return branch_convs, branch_fcs, last_layer_dim
+
+ def init_weights(self):
+ super(ConvFCBBoxHead, self).init_weights()
+ # conv layers are already initialized by ConvModule
+ for module_list in [self.shared_fcs, self.cls_fcs, self.reg_fcs]:
+ for m in module_list.modules():
+ if isinstance(m, nn.Linear):
+ nn.init.xavier_uniform_(m.weight)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ # shared part
+ if self.num_shared_convs > 0:
+ for conv in self.shared_convs:
+ x = conv(x)
+
+ if self.num_shared_fcs > 0:
+ if self.with_avg_pool:
+ x = self.avg_pool(x)
+
+ x = x.flatten(1)
+
+ for fc in self.shared_fcs:
+ x = self.relu(fc(x))
+ # separate branches
+ x_cls = x
+ x_reg = x
+
+ for conv in self.cls_convs:
+ x_cls = conv(x_cls)
+ if x_cls.dim() > 2:
+ if self.with_avg_pool:
+ x_cls = self.avg_pool(x_cls)
+ x_cls = x_cls.flatten(1)
+ for fc in self.cls_fcs:
+ x_cls = self.relu(fc(x_cls))
+
+ for conv in self.reg_convs:
+ x_reg = conv(x_reg)
+ if x_reg.dim() > 2:
+ if self.with_avg_pool:
+ x_reg = self.avg_pool(x_reg)
+ x_reg = x_reg.flatten(1)
+ for fc in self.reg_fcs:
+ x_reg = self.relu(fc(x_reg))
+
+ cls_score = self.fc_cls(x_cls) if self.with_cls else None
+ bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
+ return cls_score, bbox_pred
+
+
+@HEADS.register_module()
+class Shared2FCBBoxHead(ConvFCBBoxHead):
+
+ def __init__(self, fc_out_channels=1024, *args, **kwargs):
+ super(Shared2FCBBoxHead, self).__init__(
+ num_shared_convs=0,
+ num_shared_fcs=2,
+ num_cls_convs=0,
+ num_cls_fcs=0,
+ num_reg_convs=0,
+ num_reg_fcs=0,
+ fc_out_channels=fc_out_channels,
+ *args,
+ **kwargs)
+
+
+@HEADS.register_module()
+class Shared4Conv1FCBBoxHead(ConvFCBBoxHead):
+
+ def __init__(self, fc_out_channels=1024, *args, **kwargs):
+ super(Shared4Conv1FCBBoxHead, self).__init__(
+ num_shared_convs=4,
+ num_shared_fcs=1,
+ num_cls_convs=0,
+ num_cls_fcs=0,
+ num_reg_convs=0,
+ num_reg_fcs=0,
+ fc_out_channels=fc_out_channels,
+ *args,
+ **kwargs)
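
These heads are normally built from config dicts through the HEADS registry rather than instantiated directly. A hedged sketch with illustrative values, assuming the bundled mmdet package and torch are importable:

import torch
from mmdet.models.builder import build_head

bbox_head = build_head(
    dict(
        type='Shared2FCBBoxHead',
        in_channels=256,
        fc_out_channels=1024,
        roi_feat_size=7,
        num_classes=80))
cls_score, bbox_pred = bbox_head(torch.rand(4, 256, 7, 7))
print(cls_score.shape, bbox_pred.shape)   # torch.Size([4, 81]) torch.Size([4, 320])
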
diff --git a/mmdet/models/roi_heads/bbox_heads/dii_head.py b/mmdet/models/roi_heads/bbox_heads/dii_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c970a78184672aaaa95edcdaecec03a26604390
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/dii_head.py
@@ -0,0 +1,415 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import (bias_init_with_prob, build_activation_layer,
+ build_norm_layer)
+from mmcv.runner import auto_fp16, force_fp32
+
+from mmdet.core import multi_apply
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.dense_heads.atss_head import reduce_mean
+from mmdet.models.losses import accuracy
+from mmdet.models.utils import FFN, MultiheadAttention, build_transformer
+from .bbox_head import BBoxHead
+
+
+@HEADS.register_module()
+class DIIHead(BBoxHead):
+ r"""Dynamic Instance Interactive Head for `Sparse R-CNN: End-to-End Object
+    Detection with Learnable Proposals <https://arxiv.org/abs/2011.12450>`_
+
+ Args:
+        num_classes (int): Number of classes in the dataset.
+ Defaults to 80.
+ num_ffn_fcs (int): The number of fully-connected
+ layers in FFNs. Defaults to 2.
+        num_heads (int): The number of attention heads in
+            MultiheadAttention. Defaults to 8.
+        num_cls_fcs (int): The number of fully-connected
+            layers in classification subnet. Defaults to 1.
+        num_reg_fcs (int): The number of fully-connected
+            layers in regression subnet. Defaults to 3.
+        feedforward_channels (int): The hidden dimension
+            of FFNs. Defaults to 2048.
+        in_channels (int): Hidden channels of MultiheadAttention.
+            Defaults to 256.
+        dropout (float): Dropout probability.
+            Defaults to 0.0.
+ ffn_act_cfg (dict): The activation config for FFNs.
+ dynamic_conv_cfg (dict): The convolution config
+ for DynamicConv.
+ loss_iou (dict): The config for iou or giou loss.
+
+ """
+
+ def __init__(self,
+ num_classes=80,
+ num_ffn_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_reg_fcs=3,
+ feedforward_channels=2048,
+ in_channels=256,
+ dropout=0.0,
+ ffn_act_cfg=dict(type='ReLU', inplace=True),
+ dynamic_conv_cfg=dict(
+ type='DynamicConv',
+ in_channels=256,
+ feat_channels=64,
+ out_channels=256,
+ input_feat_shape=7,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')),
+ loss_iou=dict(type='GIoULoss', loss_weight=2.0),
+ **kwargs):
+ super(DIIHead, self).__init__(
+ num_classes=num_classes,
+ reg_decoded_bbox=True,
+ reg_class_agnostic=True,
+ **kwargs)
+ self.loss_iou = build_loss(loss_iou)
+ self.in_channels = in_channels
+ self.fp16_enabled = False
+ self.attention = MultiheadAttention(in_channels, num_heads, dropout)
+ self.attention_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.instance_interactive_conv = build_transformer(dynamic_conv_cfg)
+ self.instance_interactive_conv_dropout = nn.Dropout(dropout)
+ self.instance_interactive_conv_norm = build_norm_layer(
+ dict(type='LN'), in_channels)[1]
+
+ self.ffn = FFN(
+ in_channels,
+ feedforward_channels,
+ num_ffn_fcs,
+ act_cfg=ffn_act_cfg,
+ dropout=dropout)
+ self.ffn_norm = build_norm_layer(dict(type='LN'), in_channels)[1]
+
+ self.cls_fcs = nn.ModuleList()
+ for _ in range(num_cls_fcs):
+ self.cls_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.cls_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.cls_fcs.append(
+ build_activation_layer(dict(type='ReLU', inplace=True)))
+
+        # override the self.fc_cls built in BBoxHead
+ if self.loss_cls.use_sigmoid:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes)
+ else:
+ self.fc_cls = nn.Linear(in_channels, self.num_classes + 1)
+
+ self.reg_fcs = nn.ModuleList()
+ for _ in range(num_reg_fcs):
+ self.reg_fcs.append(
+ nn.Linear(in_channels, in_channels, bias=False))
+ self.reg_fcs.append(
+ build_norm_layer(dict(type='LN'), in_channels)[1])
+ self.reg_fcs.append(
+ build_activation_layer(dict(type='ReLU', inplace=True)))
+        # override the self.fc_reg built in BBoxHead
+ self.fc_reg = nn.Linear(in_channels, 4)
+
+        assert self.reg_class_agnostic, 'DIIHead only ' \
+            'supports `reg_class_agnostic=True` '
+        assert self.reg_decoded_bbox, 'DIIHead only ' \
+            'supports `reg_decoded_bbox=True`'
+
+ def init_weights(self):
+        """Use xavier initialization for all weight parameters and set the
+        classification head bias to a specific value when using focal loss."""
+ for p in self.parameters():
+ if p.dim() > 1:
+ nn.init.xavier_uniform_(p)
+ else:
+ # adopt the default initialization for
+ # the weight and bias of the layer norm
+ pass
+ if self.loss_cls.use_sigmoid:
+ bias_init = bias_init_with_prob(0.01)
+ nn.init.constant_(self.fc_cls.bias, bias_init)
+
+ @auto_fp16()
+ def forward(self, roi_feat, proposal_feat):
+ """Forward function of Dynamic Instance Interactive Head.
+
+ Args:
+            roi_feat (Tensor): RoI-pooled features with shape
+                (batch_size*num_proposals, feature_dimensions,
+                pooling_h, pooling_w).
+            proposal_feat (Tensor): Intermediate feature obtained from
+                the DIIHead of the previous stage, has shape
+                (batch_size, num_proposals, feature_dimensions)
+
+ Returns:
+ tuple[Tensor]: Usually a tuple of classification scores
+                and bbox prediction and an intermediate feature.
+
+ - cls_scores (Tensor): Classification scores for
+ all proposals, has shape
+ (batch_size, num_proposals, num_classes).
+ - bbox_preds (Tensor): Box energies / deltas for
+ all proposals, has shape
+ (batch_size, num_proposals, 4).
+ - obj_feat (Tensor): Object feature before classification
+ and regression subnet, has shape
+ (batch_size, num_proposal, feature_dimensions).
+ """
+ N, num_proposals = proposal_feat.shape[:2]
+
+ # Self attention
+ proposal_feat = proposal_feat.permute(1, 0, 2)
+ proposal_feat = self.attention_norm(self.attention(proposal_feat))
+
+ # instance interactive
+ proposal_feat = proposal_feat.permute(1, 0,
+ 2).reshape(-1, self.in_channels)
+ proposal_feat_iic = self.instance_interactive_conv(
+ proposal_feat, roi_feat)
+ proposal_feat = proposal_feat + self.instance_interactive_conv_dropout(
+ proposal_feat_iic)
+ obj_feat = self.instance_interactive_conv_norm(proposal_feat)
+
+ # FFN
+ obj_feat = self.ffn_norm(self.ffn(obj_feat))
+
+ cls_feat = obj_feat
+ reg_feat = obj_feat
+
+ for cls_layer in self.cls_fcs:
+ cls_feat = cls_layer(cls_feat)
+ for reg_layer in self.reg_fcs:
+ reg_feat = reg_layer(reg_feat)
+
+ cls_score = self.fc_cls(cls_feat).view(N, num_proposals, -1)
+ bbox_delta = self.fc_reg(reg_feat).view(N, num_proposals, -1)
+
+ return cls_score, bbox_delta, obj_feat.view(N, num_proposals, -1)
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def loss(self,
+ cls_score,
+ bbox_pred,
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ imgs_whwh=None,
+ reduction_override=None,
+ **kwargs):
+        """Loss function of DIIHead, computing the loss of all images.
+
+ Args:
+ cls_score (Tensor): Classification prediction
+                results of all classes, has shape
+                (batch_size * num_proposals_single_image, num_classes).
+ bbox_pred (Tensor): Regression prediction results,
+ has shape
+ (batch_size * num_proposals_single_image, 4), the last
+ dimension 4 represents [tl_x, tl_y, br_x, br_y].
+            labels (Tensor): Label of each proposal, has shape
+                (batch_size * num_proposals_single_image, ).
+            label_weights (Tensor): Classification loss
+                weight of each proposal, has shape
+                (batch_size * num_proposals_single_image, ).
+ bbox_targets (Tensor): Regression targets of each
+ proposals, has shape
+ (batch_size * num_proposals_single_image, 4),
+ the last dimension 4 represents
+ [tl_x, tl_y, br_x, br_y].
+            bbox_weights (Tensor): Regression loss weight of each
+                proposal's coordinates, has shape
+                (batch_size * num_proposals_single_image, 4).
+            imgs_whwh (Tensor): Tensor with shape
+                (batch_size, num_proposals, 4), the last dimension means
+                [img_width, img_height, img_width, img_height].
+ reduction_override (str, optional): The reduction
+ method used to override the original reduction
+ method of the loss. Options are "none",
+                "mean" and "sum". Defaults to None.
+
+ Returns:
+ dict[str, Tensor]: Dictionary of loss components
+ """
+ losses = dict()
+ bg_class_ind = self.num_classes
+        # note: in Sparse R-CNN, num_gt == num_pos
+ pos_inds = (labels >= 0) & (labels < bg_class_ind)
+ num_pos = pos_inds.sum().float()
+ avg_factor = reduce_mean(num_pos)
+ if cls_score is not None:
+ if cls_score.numel() > 0:
+ losses['loss_cls'] = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['pos_acc'] = accuracy(cls_score[pos_inds],
+ labels[pos_inds])
+ if bbox_pred is not None:
+ # 0~self.num_classes-1 are FG, self.num_classes is BG
+ # do not perform bounding box regression for BG anymore.
+ if pos_inds.any():
+ pos_bbox_pred = bbox_pred.reshape(bbox_pred.size(0),
+ 4)[pos_inds.type(torch.bool)]
+ imgs_whwh = imgs_whwh.reshape(bbox_pred.size(0),
+ 4)[pos_inds.type(torch.bool)]
+ losses['loss_bbox'] = self.loss_bbox(
+ pos_bbox_pred / imgs_whwh,
+ bbox_targets[pos_inds.type(torch.bool)] / imgs_whwh,
+ bbox_weights[pos_inds.type(torch.bool)],
+ avg_factor=avg_factor)
+ losses['loss_iou'] = self.loss_iou(
+ pos_bbox_pred,
+ bbox_targets[pos_inds.type(torch.bool)],
+ bbox_weights[pos_inds.type(torch.bool)],
+ avg_factor=avg_factor)
+ else:
+ losses['loss_bbox'] = bbox_pred.sum() * 0
+ losses['loss_iou'] = bbox_pred.sum() * 0
+ return losses
+
+ def _get_target_single(self, pos_inds, neg_inds, pos_bboxes, neg_bboxes,
+ pos_gt_bboxes, pos_gt_labels, cfg):
+        """Calculate the ground truth for proposals in a single image
+ according to the sampling results.
+
+ Almost the same as the implementation in `bbox_head`,
+ we add pos_inds and neg_inds to select positive and
+ negative samples instead of selecting the first num_pos
+ as positive samples.
+
+ Args:
+            pos_inds (Tensor): Indices of the positive samples in the
+                original proposal set; its length equals the number of
+                positive samples.
+            neg_inds (Tensor): Indices of the negative samples in the
+                original proposal set; its length equals the number of
+                negative samples.
+ pos_bboxes (Tensor): Contains all the positive boxes,
+ has shape (num_pos, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ neg_bboxes (Tensor): Contains all the negative boxes,
+ has shape (num_neg, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_bboxes (Tensor): Contains all the gt_boxes,
+ has shape (num_gt, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ pos_gt_labels (Tensor): Contains all the gt_labels,
+ has shape (num_gt).
+ cfg (obj:`ConfigDict`): `train_cfg` of R-CNN.
+
+ Returns:
+ Tuple[Tensor]: Ground truth for proposals in a single image.
+ Containing the following Tensors:
+
+ - labels(Tensor): Gt_labels for all proposals, has
+ shape (num_proposals,).
+ - label_weights(Tensor): Labels_weights for all proposals, has
+ shape (num_proposals,).
+ - bbox_targets(Tensor):Regression target for all proposals, has
+ shape (num_proposals, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ - bbox_weights(Tensor):Regression weights for all proposals,
+ has shape (num_proposals, 4).
+ """
+ num_pos = pos_bboxes.size(0)
+ num_neg = neg_bboxes.size(0)
+ num_samples = num_pos + num_neg
+
+ # original implementation uses new_zeros since BG are set to be 0
+ # now use empty & fill because BG cat_id = num_classes,
+ # FG cat_id = [0, num_classes-1]
+ labels = pos_bboxes.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_bboxes.new_zeros(num_samples)
+ bbox_targets = pos_bboxes.new_zeros(num_samples, 4)
+ bbox_weights = pos_bboxes.new_zeros(num_samples, 4)
+ if num_pos > 0:
+ labels[pos_inds] = pos_gt_labels
+ pos_weight = 1.0 if cfg.pos_weight <= 0 else cfg.pos_weight
+ label_weights[pos_inds] = pos_weight
+ if not self.reg_decoded_bbox:
+ pos_bbox_targets = self.bbox_coder.encode(
+ pos_bboxes, pos_gt_bboxes)
+ else:
+ pos_bbox_targets = pos_gt_bboxes
+ bbox_targets[pos_inds, :] = pos_bbox_targets
+ bbox_weights[pos_inds, :] = 1
+ if num_neg > 0:
+ label_weights[neg_inds] = 1.0
+
+ return labels, label_weights, bbox_targets, bbox_weights
+
+ def get_targets(self,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ concat=True):
+ """Calculate the ground truth for all samples in a batch according to
+ the sampling_results.
+
+        Almost the same as the implementation in bbox_head; here we pass the
+        additional parameters pos_inds_list and neg_inds_list to the
+        `_get_target_single` function.
+
+ Args:
+ sampling_results (List[obj:SamplingResults]): Assign results of
+ all images in a batch after sampling.
+ gt_bboxes (list[Tensor]): Gt_bboxes of all images in a batch,
+ each tensor has shape (num_gt, 4), the last dimension 4
+ represents [tl_x, tl_y, br_x, br_y].
+ gt_labels (list[Tensor]): Gt_labels of all images in a batch,
+ each tensor has shape (num_gt,).
+ rcnn_train_cfg (obj:`ConfigDict`): `train_cfg` of RCNN.
+ concat (bool): Whether to concatenate the results of all
+ the images in a single batch.
+
+ Returns:
+            Tuple[Tensor]: Ground truth for all proposals in a batch.
+                Containing the following list of Tensors:
+
+ - labels (list[Tensor],Tensor): Gt_labels for all
+ proposals in a batch, each tensor in list has
+ shape (num_proposals,) when `concat=False`, otherwise just
+ a single tensor has shape (num_all_proposals,).
+ - label_weights (list[Tensor]): Labels_weights for
+ all proposals in a batch, each tensor in list has shape
+ (num_proposals,) when `concat=False`, otherwise just a
+ single tensor has shape (num_all_proposals,).
+ - bbox_targets (list[Tensor],Tensor): Regression target
+ for all proposals in a batch, each tensor in list has
+ shape (num_proposals, 4) when `concat=False`, otherwise
+ just a single tensor has shape (num_all_proposals, 4),
+ the last dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ - bbox_weights (list[tensor],Tensor): Regression weights for
+ all proposals in a batch, each tensor in list has shape
+ (num_proposals, 4) when `concat=False`, otherwise just a
+ single tensor has shape (num_all_proposals, 4).
+ """
+ pos_inds_list = [res.pos_inds for res in sampling_results]
+ neg_inds_list = [res.neg_inds for res in sampling_results]
+ pos_bboxes_list = [res.pos_bboxes for res in sampling_results]
+ neg_bboxes_list = [res.neg_bboxes for res in sampling_results]
+ pos_gt_bboxes_list = [res.pos_gt_bboxes for res in sampling_results]
+ pos_gt_labels_list = [res.pos_gt_labels for res in sampling_results]
+ labels, label_weights, bbox_targets, bbox_weights = multi_apply(
+ self._get_target_single,
+ pos_inds_list,
+ neg_inds_list,
+ pos_bboxes_list,
+ neg_bboxes_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ cfg=rcnn_train_cfg)
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ bbox_targets = torch.cat(bbox_targets, 0)
+ bbox_weights = torch.cat(bbox_weights, 0)
+ return labels, label_weights, bbox_targets, bbox_weights
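
DIIHead is consumed through Sparse R-CNN style configs rather than called by hand. The sketch below is illustrative only; the values simply mirror the defaults documented above, and it assumes the bundled mmdet package (with DynamicConv registered as a transformer module) is importable.

from mmdet.models.builder import build_head

dii_head = build_head(
    dict(
        type='DIIHead',
        num_classes=80,
        num_heads=8,
        feedforward_channels=2048,
        in_channels=256,
        dropout=0.0,
        loss_iou=dict(type='GIoULoss', loss_weight=2.0)))
dii_head.init_weights()
# Per the docstrings above, forward expects RoI features of shape
# (batch_size*num_proposals, 256, 7, 7) together with proposal features of
# shape (batch_size, num_proposals, 256), and returns classification scores,
# box deltas and the intermediate object features.
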
diff --git a/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c154cb3c0d9d7639c3d4a2a1272406d3fab8acd
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/double_bbox_head.py
@@ -0,0 +1,172 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, normal_init, xavier_init
+
+from mmdet.models.backbones.resnet import Bottleneck
+from mmdet.models.builder import HEADS
+from .bbox_head import BBoxHead
+
+
+class BasicResBlock(nn.Module):
+ """Basic residual block.
+
+ This block is a little different from the block in the ResNet backbone.
+    The kernel size of conv2 is 1 in this block while 3 in ResNet BasicBlock.
+
+ Args:
+ in_channels (int): Channels of the input feature map.
+ out_channels (int): Channels of the output feature map.
+ conv_cfg (dict): The config dict for convolution layers.
+ norm_cfg (dict): The config dict for normalization layers.
+ """
+
+ def __init__(self,
+ in_channels,
+ out_channels,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN')):
+ super(BasicResBlock, self).__init__()
+
+ # main path
+ self.conv1 = ConvModule(
+ in_channels,
+ in_channels,
+ kernel_size=3,
+ padding=1,
+ bias=False,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg)
+ self.conv2 = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ bias=False,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ # identity path
+ self.conv_identity = ConvModule(
+ in_channels,
+ out_channels,
+ kernel_size=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=None)
+
+ self.relu = nn.ReLU(inplace=True)
+
+ def forward(self, x):
+ identity = x
+
+ x = self.conv1(x)
+ x = self.conv2(x)
+
+ identity = self.conv_identity(identity)
+ out = x + identity
+
+ out = self.relu(out)
+ return out
+
+
+@HEADS.register_module()
+class DoubleConvFCBBoxHead(BBoxHead):
+ r"""Bbox head used in Double-Head R-CNN
+
+ .. code-block:: none
+
+ /-> cls
+ /-> shared convs ->
+ \-> reg
+ roi features
+ /-> cls
+ \-> shared fc ->
+ \-> reg
+ """ # noqa: W605
+
+ def __init__(self,
+ num_convs=0,
+ num_fcs=0,
+ conv_out_channels=1024,
+ fc_out_channels=1024,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ **kwargs):
+ kwargs.setdefault('with_avg_pool', True)
+ super(DoubleConvFCBBoxHead, self).__init__(**kwargs)
+ assert self.with_avg_pool
+ assert num_convs > 0
+ assert num_fcs > 0
+ self.num_convs = num_convs
+ self.num_fcs = num_fcs
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+
+ # increase the channel of input features
+ self.res_block = BasicResBlock(self.in_channels,
+ self.conv_out_channels)
+
+ # add conv heads
+ self.conv_branch = self._add_conv_branch()
+ # add fc heads
+ self.fc_branch = self._add_fc_branch()
+
+ out_dim_reg = 4 if self.reg_class_agnostic else 4 * self.num_classes
+ self.fc_reg = nn.Linear(self.conv_out_channels, out_dim_reg)
+
+ self.fc_cls = nn.Linear(self.fc_out_channels, self.num_classes + 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ def _add_conv_branch(self):
+        """Add the conv branch which consists of a sequence of conv layers."""
+ branch_convs = nn.ModuleList()
+ for i in range(self.num_convs):
+ branch_convs.append(
+ Bottleneck(
+ inplanes=self.conv_out_channels,
+ planes=self.conv_out_channels // 4,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ return branch_convs
+
+ def _add_fc_branch(self):
+        """Add the fc branch which consists of a sequence of fc layers."""
+ branch_fcs = nn.ModuleList()
+ for i in range(self.num_fcs):
+ fc_in_channels = (
+ self.in_channels *
+ self.roi_feat_area if i == 0 else self.fc_out_channels)
+ branch_fcs.append(nn.Linear(fc_in_channels, self.fc_out_channels))
+ return branch_fcs
+
+ def init_weights(self):
+ # conv layers are already initialized by ConvModule
+ normal_init(self.fc_cls, std=0.01)
+ normal_init(self.fc_reg, std=0.001)
+
+ for m in self.fc_branch.modules():
+ if isinstance(m, nn.Linear):
+ xavier_init(m, distribution='uniform')
+
+ def forward(self, x_cls, x_reg):
+ # conv head
+ x_conv = self.res_block(x_reg)
+
+ for conv in self.conv_branch:
+ x_conv = conv(x_conv)
+
+ if self.with_avg_pool:
+ x_conv = self.avg_pool(x_conv)
+
+ x_conv = x_conv.view(x_conv.size(0), -1)
+ bbox_pred = self.fc_reg(x_conv)
+
+ # fc head
+ x_fc = x_cls.view(x_cls.size(0), -1)
+ for fc in self.fc_branch:
+ x_fc = self.relu(fc(x_fc))
+
+ cls_score = self.fc_cls(x_fc)
+
+ return cls_score, bbox_pred
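
Unlike the single-input heads above, DoubleConvFCBBoxHead takes separate classification and regression features. A small illustrative check, assuming the bundled mmdet package and torch are importable:

import torch
from mmdet.models.roi_heads.bbox_heads import DoubleConvFCBBoxHead

head = DoubleConvFCBBoxHead(
    num_convs=4,
    num_fcs=2,
    in_channels=256,
    conv_out_channels=1024,
    fc_out_channels=1024,
    num_classes=80)
head.init_weights()
x = torch.rand(2, 256, 7, 7)              # pooled RoI features
cls_score, bbox_pred = head(x_cls=x, x_reg=x)
print(cls_score.shape, bbox_pred.shape)   # torch.Size([2, 81]) torch.Size([2, 320])
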
diff --git a/mmdet/models/roi_heads/bbox_heads/sabl_head.py b/mmdet/models/roi_heads/bbox_heads/sabl_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5153996aeb706d103d1ad14b61734914eddb7693
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/sabl_head.py
@@ -0,0 +1,572 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, kaiming_init, normal_init, xavier_init
+from mmcv.runner import force_fp32
+
+from mmdet.core import build_bbox_coder, multi_apply, multiclass_nms
+from mmdet.models.builder import HEADS, build_loss
+from mmdet.models.losses import accuracy
+
+
+@HEADS.register_module()
+class SABLHead(nn.Module):
+ """Side-Aware Boundary Localization (SABL) for RoI-Head.
+
+ Side-Aware features are extracted by conv layers
+ with an attention mechanism.
+ Boundary Localization with Bucketing and Bucketing Guided Rescoring
+ are implemented in BucketingBBoxCoder.
+
+ Please refer to https://arxiv.org/abs/1912.04260 for more details.
+
+ Args:
+ cls_in_channels (int): Input channels of cls RoI feature. \
+ Defaults to 256.
+ reg_in_channels (int): Input channels of reg RoI feature. \
+ Defaults to 256.
+ roi_feat_size (int): Size of RoI features. Defaults to 7.
+ reg_feat_up_ratio (int): Upsample ratio of reg features. \
+ Defaults to 2.
+ reg_pre_kernel (int): Kernel of 2D conv layers before \
+ attention pooling. Defaults to 3.
+ reg_post_kernel (int): Kernel of 1D conv layers after \
+ attention pooling. Defaults to 3.
+ reg_pre_num (int): Number of pre convs. Defaults to 2.
+ reg_post_num (int): Number of post convs. Defaults to 1.
+ num_classes (int): Number of classes in dataset. Defaults to 80.
+ cls_out_channels (int): Hidden channels in cls fcs. Defaults to 1024.
+ reg_offset_out_channels (int): Hidden and output channel \
+ of reg offset branch. Defaults to 256.
+ reg_cls_out_channels (int): Hidden and output channel \
+ of reg cls branch. Defaults to 256.
+ num_cls_fcs (int): Number of fcs for cls branch. Defaults to 1.
+        num_reg_fcs (int): Number of fcs for reg branch. Defaults to 0.
+        reg_class_agnostic (bool): Class agnostic regression or not. \
+            Defaults to True.
+ norm_cfg (dict): Config of norm layers. Defaults to None.
+        bbox_coder (dict): Config of bbox coder. Defaults to 'BucketingBBoxCoder'.
+ loss_cls (dict): Config of classification loss.
+ loss_bbox_cls (dict): Config of classification loss for bbox branch.
+ loss_bbox_reg (dict): Config of regression loss for bbox branch.
+ """
+
+ def __init__(self,
+ num_classes,
+ cls_in_channels=256,
+ reg_in_channels=256,
+ roi_feat_size=7,
+ reg_feat_up_ratio=2,
+ reg_pre_kernel=3,
+ reg_post_kernel=3,
+ reg_pre_num=2,
+ reg_post_num=1,
+ cls_out_channels=1024,
+ reg_offset_out_channels=256,
+ reg_cls_out_channels=256,
+ num_cls_fcs=1,
+ num_reg_fcs=0,
+ reg_class_agnostic=True,
+ norm_cfg=None,
+ bbox_coder=dict(
+ type='BucketingBBoxCoder',
+ num_buckets=14,
+ scale_factor=1.7),
+ loss_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=False,
+ loss_weight=1.0),
+ loss_bbox_cls=dict(
+ type='CrossEntropyLoss',
+ use_sigmoid=True,
+ loss_weight=1.0),
+ loss_bbox_reg=dict(
+ type='SmoothL1Loss', beta=0.1, loss_weight=1.0)):
+ super(SABLHead, self).__init__()
+ self.cls_in_channels = cls_in_channels
+ self.reg_in_channels = reg_in_channels
+ self.roi_feat_size = roi_feat_size
+ self.reg_feat_up_ratio = int(reg_feat_up_ratio)
+ self.num_buckets = bbox_coder['num_buckets']
+ assert self.reg_feat_up_ratio // 2 >= 1
+ self.up_reg_feat_size = roi_feat_size * self.reg_feat_up_ratio
+ assert self.up_reg_feat_size == bbox_coder['num_buckets']
+ self.reg_pre_kernel = reg_pre_kernel
+ self.reg_post_kernel = reg_post_kernel
+ self.reg_pre_num = reg_pre_num
+ self.reg_post_num = reg_post_num
+ self.num_classes = num_classes
+ self.cls_out_channels = cls_out_channels
+ self.reg_offset_out_channels = reg_offset_out_channels
+ self.reg_cls_out_channels = reg_cls_out_channels
+ self.num_cls_fcs = num_cls_fcs
+ self.num_reg_fcs = num_reg_fcs
+ self.reg_class_agnostic = reg_class_agnostic
+ assert self.reg_class_agnostic
+ self.norm_cfg = norm_cfg
+
+ self.bbox_coder = build_bbox_coder(bbox_coder)
+ self.loss_cls = build_loss(loss_cls)
+ self.loss_bbox_cls = build_loss(loss_bbox_cls)
+ self.loss_bbox_reg = build_loss(loss_bbox_reg)
+
+ self.cls_fcs = self._add_fc_branch(self.num_cls_fcs,
+ self.cls_in_channels,
+ self.roi_feat_size,
+ self.cls_out_channels)
+
+ self.side_num = int(np.ceil(self.num_buckets / 2))
+
+ if self.reg_feat_up_ratio > 1:
+ self.upsample_x = nn.ConvTranspose1d(
+ reg_in_channels,
+ reg_in_channels,
+ self.reg_feat_up_ratio,
+ stride=self.reg_feat_up_ratio)
+ self.upsample_y = nn.ConvTranspose1d(
+ reg_in_channels,
+ reg_in_channels,
+ self.reg_feat_up_ratio,
+ stride=self.reg_feat_up_ratio)
+
+ self.reg_pre_convs = nn.ModuleList()
+ for i in range(self.reg_pre_num):
+ reg_pre_conv = ConvModule(
+ reg_in_channels,
+ reg_in_channels,
+ kernel_size=reg_pre_kernel,
+ padding=reg_pre_kernel // 2,
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.reg_pre_convs.append(reg_pre_conv)
+
+ self.reg_post_conv_xs = nn.ModuleList()
+ for i in range(self.reg_post_num):
+ reg_post_conv_x = ConvModule(
+ reg_in_channels,
+ reg_in_channels,
+ kernel_size=(1, reg_post_kernel),
+ padding=(0, reg_post_kernel // 2),
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.reg_post_conv_xs.append(reg_post_conv_x)
+ self.reg_post_conv_ys = nn.ModuleList()
+ for i in range(self.reg_post_num):
+ reg_post_conv_y = ConvModule(
+ reg_in_channels,
+ reg_in_channels,
+ kernel_size=(reg_post_kernel, 1),
+ padding=(reg_post_kernel // 2, 0),
+ norm_cfg=norm_cfg,
+ act_cfg=dict(type='ReLU'))
+ self.reg_post_conv_ys.append(reg_post_conv_y)
+
+ self.reg_conv_att_x = nn.Conv2d(reg_in_channels, 1, 1)
+ self.reg_conv_att_y = nn.Conv2d(reg_in_channels, 1, 1)
+
+ self.fc_cls = nn.Linear(self.cls_out_channels, self.num_classes + 1)
+ self.relu = nn.ReLU(inplace=True)
+
+ self.reg_cls_fcs = self._add_fc_branch(self.num_reg_fcs,
+ self.reg_in_channels, 1,
+ self.reg_cls_out_channels)
+ self.reg_offset_fcs = self._add_fc_branch(self.num_reg_fcs,
+ self.reg_in_channels, 1,
+ self.reg_offset_out_channels)
+ self.fc_reg_cls = nn.Linear(self.reg_cls_out_channels, 1)
+ self.fc_reg_offset = nn.Linear(self.reg_offset_out_channels, 1)
+
+ def _add_fc_branch(self, num_branch_fcs, in_channels, roi_feat_size,
+ fc_out_channels):
+ in_channels = in_channels * roi_feat_size * roi_feat_size
+ branch_fcs = nn.ModuleList()
+ for i in range(num_branch_fcs):
+ fc_in_channels = (in_channels if i == 0 else fc_out_channels)
+ branch_fcs.append(nn.Linear(fc_in_channels, fc_out_channels))
+ return branch_fcs
+
+ def init_weights(self):
+ for module_list in [
+ self.reg_cls_fcs, self.reg_offset_fcs, self.cls_fcs
+ ]:
+ for m in module_list.modules():
+ if isinstance(m, nn.Linear):
+ xavier_init(m, distribution='uniform')
+ if self.reg_feat_up_ratio > 1:
+ kaiming_init(self.upsample_x, distribution='normal')
+ kaiming_init(self.upsample_y, distribution='normal')
+
+ normal_init(self.reg_conv_att_x, 0, 0.01)
+ normal_init(self.reg_conv_att_y, 0, 0.01)
+ normal_init(self.fc_reg_offset, 0, 0.001)
+ normal_init(self.fc_reg_cls, 0, 0.01)
+ normal_init(self.fc_cls, 0, 0.01)
+
+ def cls_forward(self, cls_x):
+ cls_x = cls_x.view(cls_x.size(0), -1)
+ for fc in self.cls_fcs:
+ cls_x = self.relu(fc(cls_x))
+ cls_score = self.fc_cls(cls_x)
+ return cls_score
+
+ def attention_pool(self, reg_x):
+        """Extract direction-specific features fx and fy with attention
+        mechanism."""
+ reg_fx = reg_x
+ reg_fy = reg_x
+ reg_fx_att = self.reg_conv_att_x(reg_fx).sigmoid()
+ reg_fy_att = self.reg_conv_att_y(reg_fy).sigmoid()
+ reg_fx_att = reg_fx_att / reg_fx_att.sum(dim=2).unsqueeze(2)
+ reg_fy_att = reg_fy_att / reg_fy_att.sum(dim=3).unsqueeze(3)
+ reg_fx = (reg_fx * reg_fx_att).sum(dim=2)
+ reg_fy = (reg_fy * reg_fy_att).sum(dim=3)
+ return reg_fx, reg_fy
+
+ def side_aware_feature_extractor(self, reg_x):
+        """Refine and extract side-aware features without splitting them."""
+ for reg_pre_conv in self.reg_pre_convs:
+ reg_x = reg_pre_conv(reg_x)
+ reg_fx, reg_fy = self.attention_pool(reg_x)
+
+ if self.reg_post_num > 0:
+ reg_fx = reg_fx.unsqueeze(2)
+ reg_fy = reg_fy.unsqueeze(3)
+ for i in range(self.reg_post_num):
+ reg_fx = self.reg_post_conv_xs[i](reg_fx)
+ reg_fy = self.reg_post_conv_ys[i](reg_fy)
+ reg_fx = reg_fx.squeeze(2)
+ reg_fy = reg_fy.squeeze(3)
+ if self.reg_feat_up_ratio > 1:
+ reg_fx = self.relu(self.upsample_x(reg_fx))
+ reg_fy = self.relu(self.upsample_y(reg_fy))
+ reg_fx = torch.transpose(reg_fx, 1, 2)
+ reg_fy = torch.transpose(reg_fy, 1, 2)
+ return reg_fx.contiguous(), reg_fy.contiguous()
+
+ def reg_pred(self, x, offset_fcs, cls_fcs):
+ """Predict bucketing estimation (cls_pred) and fine regression (offset
+ pred) with side-aware features."""
+ x_offset = x.view(-1, self.reg_in_channels)
+ x_cls = x.view(-1, self.reg_in_channels)
+
+ for fc in offset_fcs:
+ x_offset = self.relu(fc(x_offset))
+ for fc in cls_fcs:
+ x_cls = self.relu(fc(x_cls))
+ offset_pred = self.fc_reg_offset(x_offset)
+ cls_pred = self.fc_reg_cls(x_cls)
+
+ offset_pred = offset_pred.view(x.size(0), -1)
+ cls_pred = cls_pred.view(x.size(0), -1)
+
+ return offset_pred, cls_pred
+
+ def side_aware_split(self, feat):
+ """Split side-aware features aligned with orders of bucketing
+ targets."""
+ l_end = int(np.ceil(self.up_reg_feat_size / 2))
+ r_start = int(np.floor(self.up_reg_feat_size / 2))
+ feat_fl = feat[:, :l_end]
+ feat_fr = feat[:, r_start:].flip(dims=(1, ))
+ feat_fl = feat_fl.contiguous()
+ feat_fr = feat_fr.contiguous()
+ feat = torch.cat([feat_fl, feat_fr], dim=-1)
+ return feat
+
+ def bbox_pred_split(self, bbox_pred, num_proposals_per_img):
+ """Split batch bbox prediction back to each image."""
+ bucket_cls_preds, bucket_offset_preds = bbox_pred
+ bucket_cls_preds = bucket_cls_preds.split(num_proposals_per_img, 0)
+ bucket_offset_preds = bucket_offset_preds.split(
+ num_proposals_per_img, 0)
+ bbox_pred = tuple(zip(bucket_cls_preds, bucket_offset_preds))
+ return bbox_pred
+
+ def reg_forward(self, reg_x):
+ outs = self.side_aware_feature_extractor(reg_x)
+ edge_offset_preds = []
+ edge_cls_preds = []
+ reg_fx = outs[0]
+ reg_fy = outs[1]
+ offset_pred_x, cls_pred_x = self.reg_pred(reg_fx, self.reg_offset_fcs,
+ self.reg_cls_fcs)
+ offset_pred_y, cls_pred_y = self.reg_pred(reg_fy, self.reg_offset_fcs,
+ self.reg_cls_fcs)
+ offset_pred_x = self.side_aware_split(offset_pred_x)
+ offset_pred_y = self.side_aware_split(offset_pred_y)
+ cls_pred_x = self.side_aware_split(cls_pred_x)
+ cls_pred_y = self.side_aware_split(cls_pred_y)
+ edge_offset_preds = torch.cat([offset_pred_x, offset_pred_y], dim=-1)
+ edge_cls_preds = torch.cat([cls_pred_x, cls_pred_y], dim=-1)
+
+ return (edge_cls_preds, edge_offset_preds)
+
+ def forward(self, x):
+ bbox_pred = self.reg_forward(x)
+ cls_score = self.cls_forward(x)
+
+ return cls_score, bbox_pred
+
+ def get_targets(self, sampling_results, gt_bboxes, gt_labels,
+ rcnn_train_cfg):
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ neg_proposals = [res.neg_bboxes for res in sampling_results]
+ pos_gt_bboxes = [res.pos_gt_bboxes for res in sampling_results]
+ pos_gt_labels = [res.pos_gt_labels for res in sampling_results]
+ cls_reg_targets = self.bucket_target(pos_proposals, neg_proposals,
+ pos_gt_bboxes, pos_gt_labels,
+ rcnn_train_cfg)
+ (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights) = cls_reg_targets
+ return (labels, label_weights, (bucket_cls_targets,
+ bucket_offset_targets),
+ (bucket_cls_weights, bucket_offset_weights))
+
+ def bucket_target(self,
+ pos_proposals_list,
+ neg_proposals_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ rcnn_train_cfg,
+ concat=True):
+ (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights) = multi_apply(
+ self._bucket_target_single,
+ pos_proposals_list,
+ neg_proposals_list,
+ pos_gt_bboxes_list,
+ pos_gt_labels_list,
+ cfg=rcnn_train_cfg)
+
+ if concat:
+ labels = torch.cat(labels, 0)
+ label_weights = torch.cat(label_weights, 0)
+ bucket_cls_targets = torch.cat(bucket_cls_targets, 0)
+ bucket_cls_weights = torch.cat(bucket_cls_weights, 0)
+ bucket_offset_targets = torch.cat(bucket_offset_targets, 0)
+ bucket_offset_weights = torch.cat(bucket_offset_weights, 0)
+ return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights)
+
+ def _bucket_target_single(self, pos_proposals, neg_proposals,
+ pos_gt_bboxes, pos_gt_labels, cfg):
+ """Compute bucketing estimation targets and fine regression targets for
+ a single image.
+
+ Args:
+ pos_proposals (Tensor): positive proposals of a single image,
+ Shape (n_pos, 4)
+ neg_proposals (Tensor): negative proposals of a single image,
+ Shape (n_neg, 4).
+ pos_gt_bboxes (Tensor): gt bboxes assigned to positive proposals
+ of a single image, Shape (n_pos, 4).
+ pos_gt_labels (Tensor): gt labels assigned to positive proposals
+ of a single image, Shape (n_pos, ).
+ cfg (dict): Config of calculating targets
+
+ Returns:
+ tuple:
+
+ - labels (Tensor): Labels in a single image. \
+ Shape (n,).
+ - label_weights (Tensor): Label weights in a single image.\
+ Shape (n,)
+ - bucket_cls_targets (Tensor): Bucket cls targets in \
+ a single image. Shape (n, num_buckets*2).
+ - bucket_cls_weights (Tensor): Bucket cls weights in \
+ a single image. Shape (n, num_buckets*2).
+ - bucket_offset_targets (Tensor): Bucket offset targets \
+ in a single image. Shape (n, num_buckets*2).
+                - bucket_offset_weights (Tensor): Bucket offset weights \
+ in a single image. Shape (n, num_buckets*2).
+ """
+ num_pos = pos_proposals.size(0)
+ num_neg = neg_proposals.size(0)
+ num_samples = num_pos + num_neg
+ labels = pos_gt_bboxes.new_full((num_samples, ),
+ self.num_classes,
+ dtype=torch.long)
+ label_weights = pos_proposals.new_zeros(num_samples)
+ bucket_cls_targets = pos_proposals.new_zeros(num_samples,
+ 4 * self.side_num)
+ bucket_cls_weights = pos_proposals.new_zeros(num_samples,
+ 4 * self.side_num)
+ bucket_offset_targets = pos_proposals.new_zeros(
+ num_samples, 4 * self.side_num)
+ bucket_offset_weights = pos_proposals.new_zeros(
+ num_samples, 4 * self.side_num)
+ if num_pos > 0:
+ labels[:num_pos] = pos_gt_labels
+ label_weights[:num_pos] = 1.0
+ (pos_bucket_offset_targets, pos_bucket_offset_weights,
+ pos_bucket_cls_targets,
+ pos_bucket_cls_weights) = self.bbox_coder.encode(
+ pos_proposals, pos_gt_bboxes)
+ bucket_cls_targets[:num_pos, :] = pos_bucket_cls_targets
+ bucket_cls_weights[:num_pos, :] = pos_bucket_cls_weights
+ bucket_offset_targets[:num_pos, :] = pos_bucket_offset_targets
+ bucket_offset_weights[:num_pos, :] = pos_bucket_offset_weights
+ if num_neg > 0:
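+            # negatives carry no bucket targets; they only contribute to the
+            # classification loss through their label weight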
+ label_weights[-num_neg:] = 1.0
+ return (labels, label_weights, bucket_cls_targets, bucket_cls_weights,
+ bucket_offset_targets, bucket_offset_weights)
+
+ def loss(self,
+ cls_score,
+ bbox_pred,
+ rois,
+ labels,
+ label_weights,
+ bbox_targets,
+ bbox_weights,
+ reduction_override=None):
+ losses = dict()
+ if cls_score is not None:
+ avg_factor = max(torch.sum(label_weights > 0).float().item(), 1.)
+ losses['loss_cls'] = self.loss_cls(
+ cls_score,
+ labels,
+ label_weights,
+ avg_factor=avg_factor,
+ reduction_override=reduction_override)
+ losses['acc'] = accuracy(cls_score, labels)
+
+ if bbox_pred is not None:
+ bucket_cls_preds, bucket_offset_preds = bbox_pred
+ bucket_cls_targets, bucket_offset_targets = bbox_targets
+ bucket_cls_weights, bucket_offset_weights = bbox_weights
+ # edge cls
+ bucket_cls_preds = bucket_cls_preds.view(-1, self.side_num)
+ bucket_cls_targets = bucket_cls_targets.view(-1, self.side_num)
+ bucket_cls_weights = bucket_cls_weights.view(-1, self.side_num)
+ losses['loss_bbox_cls'] = self.loss_bbox_cls(
+ bucket_cls_preds,
+ bucket_cls_targets,
+ bucket_cls_weights,
+ avg_factor=bucket_cls_targets.size(0),
+ reduction_override=reduction_override)
+
+ losses['loss_bbox_reg'] = self.loss_bbox_reg(
+ bucket_offset_preds,
+ bucket_offset_targets,
+ bucket_offset_weights,
+ avg_factor=bucket_offset_targets.size(0),
+ reduction_override=reduction_override)
+
+ return losses
+
+ @force_fp32(apply_to=('cls_score', 'bbox_pred'))
+ def get_bboxes(self,
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None):
+ if isinstance(cls_score, list):
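+            # if scores arrive as a list (e.g. collected over several stages
+            # or augmentations), average them into one tensor before softmax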
+ cls_score = sum(cls_score) / float(len(cls_score))
+ scores = F.softmax(cls_score, dim=1) if cls_score is not None else None
+
+ if bbox_pred is not None:
+ bboxes, confids = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
+ img_shape)
+ else:
+ bboxes = rois[:, 1:].clone()
+ confids = None
+ if img_shape is not None:
+ bboxes[:, [0, 2]].clamp_(min=0, max=img_shape[1] - 1)
+ bboxes[:, [1, 3]].clamp_(min=0, max=img_shape[0] - 1)
+
+ if rescale and bboxes.size(0) > 0:
+ if isinstance(scale_factor, float):
+ bboxes /= scale_factor
+ else:
+ bboxes /= torch.from_numpy(scale_factor).to(bboxes.device)
+
+ if cfg is None:
+ return bboxes, scores
+ else:
+ det_bboxes, det_labels = multiclass_nms(
+ bboxes,
+ scores,
+ cfg.score_thr,
+ cfg.nms,
+ cfg.max_per_img,
+ score_factors=confids)
+
+ return det_bboxes, det_labels
+
+ @force_fp32(apply_to=('bbox_preds', ))
+ def refine_bboxes(self, rois, labels, bbox_preds, pos_is_gts, img_metas):
+ """Refine bboxes during training.
+
+ Args:
+ rois (Tensor): Shape (n*bs, 5), where n is image number per GPU,
+ and bs is the sampled RoIs per image.
+ labels (Tensor): Shape (n*bs, ).
+ bbox_preds (list[Tensor]): Shape [(n*bs, num_buckets*2), \
+ (n*bs, num_buckets*2)].
+ pos_is_gts (list[Tensor]): Flags indicating if each positive bbox
+ is a gt bbox.
+ img_metas (list[dict]): Meta info of each image.
+
+ Returns:
+ list[Tensor]: Refined bboxes of each image in a mini-batch.
+ """
+ img_ids = rois[:, 0].long().unique(sorted=True)
+ assert img_ids.numel() == len(img_metas)
+
+ bboxes_list = []
+ for i in range(len(img_metas)):
+ inds = torch.nonzero(
+ rois[:, 0] == i, as_tuple=False).squeeze(dim=1)
+ num_rois = inds.numel()
+
+ bboxes_ = rois[inds, 1:]
+ label_ = labels[inds]
+ edge_cls_preds, edge_offset_preds = bbox_preds
+ edge_cls_preds_ = edge_cls_preds[inds]
+ edge_offset_preds_ = edge_offset_preds[inds]
+ bbox_pred_ = [edge_cls_preds_, edge_offset_preds_]
+ img_meta_ = img_metas[i]
+ pos_is_gts_ = pos_is_gts[i]
+
+ bboxes = self.regress_by_class(bboxes_, label_, bbox_pred_,
+ img_meta_)
+ # filter gt bboxes
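+            # positives flagged as ground-truth boxes are dropped so they are
+            # not passed on as proposals for the next stage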
+ pos_keep = 1 - pos_is_gts_
+ keep_inds = pos_is_gts_.new_ones(num_rois)
+ keep_inds[:len(pos_is_gts_)] = pos_keep
+
+ bboxes_list.append(bboxes[keep_inds.type(torch.bool)])
+
+ return bboxes_list
+
+ @force_fp32(apply_to=('bbox_pred', ))
+ def regress_by_class(self, rois, label, bbox_pred, img_meta):
+ """Regress the bbox for the predicted class. Used in Cascade R-CNN.
+
+ Args:
+ rois (Tensor): shape (n, 4) or (n, 5)
+ label (Tensor): shape (n, )
+ bbox_pred (list[Tensor]): shape [(n, num_buckets *2), \
+ (n, num_buckets *2)]
+ img_meta (dict): Image meta info.
+
+ Returns:
+ Tensor: Regressed bboxes, the same shape as input rois.
+ """
+ assert rois.size(1) == 4 or rois.size(1) == 5
+
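+        # 5-column rois carry the batch index in column 0; decode only the box
+        # coordinates and re-attach the index afterwards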
+ if rois.size(1) == 4:
+ new_rois, _ = self.bbox_coder.decode(rois, bbox_pred,
+ img_meta['img_shape'])
+ else:
+ bboxes, _ = self.bbox_coder.decode(rois[:, 1:], bbox_pred,
+ img_meta['img_shape'])
+ new_rois = torch.cat((rois[:, [0]], bboxes), dim=1)
+
+ return new_rois
diff --git a/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py b/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..35758f4f4e3b2bddd460edb8a7f482b3a9da2919
--- /dev/null
+++ b/mmdet/models/roi_heads/bbox_heads/scnet_bbox_head.py
@@ -0,0 +1,76 @@
+from mmdet.models.builder import HEADS
+from .convfc_bbox_head import ConvFCBBoxHead
+
+
+@HEADS.register_module()
+class SCNetBBoxHead(ConvFCBBoxHead):
+ """BBox head for `SCNet `_.
+
+ This inherits ``ConvFCBBoxHead`` with modified forward() function, allow us
+ to get intermediate shared feature.
+ """
+
+ def _forward_shared(self, x):
+ """Forward function for shared part."""
+ if self.num_shared_convs > 0:
+ for conv in self.shared_convs:
+ x = conv(x)
+
+ if self.num_shared_fcs > 0:
+ if self.with_avg_pool:
+ x = self.avg_pool(x)
+
+ x = x.flatten(1)
+
+ for fc in self.shared_fcs:
+ x = self.relu(fc(x))
+
+ return x
+
+ def _forward_cls_reg(self, x):
+ """Forward function for classification and regression parts."""
+ x_cls = x
+ x_reg = x
+
+ for conv in self.cls_convs:
+ x_cls = conv(x_cls)
+ if x_cls.dim() > 2:
+ if self.with_avg_pool:
+ x_cls = self.avg_pool(x_cls)
+ x_cls = x_cls.flatten(1)
+ for fc in self.cls_fcs:
+ x_cls = self.relu(fc(x_cls))
+
+ for conv in self.reg_convs:
+ x_reg = conv(x_reg)
+ if x_reg.dim() > 2:
+ if self.with_avg_pool:
+ x_reg = self.avg_pool(x_reg)
+ x_reg = x_reg.flatten(1)
+ for fc in self.reg_fcs:
+ x_reg = self.relu(fc(x_reg))
+
+ cls_score = self.fc_cls(x_cls) if self.with_cls else None
+ bbox_pred = self.fc_reg(x_reg) if self.with_reg else None
+
+ return cls_score, bbox_pred
+
+ def forward(self, x, return_shared_feat=False):
+ """Forward function.
+
+ Args:
+ x (Tensor): input features
+ return_shared_feat (bool): If True, return cls-reg-shared feature.
+
+        Returns:
+            out (tuple[Tensor]): contains ``cls_score`` and ``bbox_pred``;
+                if ``return_shared_feat`` is True, ``x_shared`` is appended
+                to the returned tuple.
+ """
+ x_shared = self._forward_shared(x)
+ out = self._forward_cls_reg(x_shared)
+
+ if return_shared_feat:
+ out += (x_shared, )
+
+ return out
diff --git a/mmdet/models/roi_heads/cascade_roi_head.py b/mmdet/models/roi_heads/cascade_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..45b6f36a386cd37c50cc43666fcc516f2e14d868
--- /dev/null
+++ b/mmdet/models/roi_heads/cascade_roi_head.py
@@ -0,0 +1,507 @@
+import torch
+import torch.nn as nn
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, build_assigner,
+ build_sampler, merge_aug_bboxes, merge_aug_masks,
+ multiclass_nms)
+from ..builder import HEADS, build_head, build_roi_extractor
+from .base_roi_head import BaseRoIHead
+from .test_mixins import BBoxTestMixin, MaskTestMixin
+
+
+@HEADS.register_module()
+class CascadeRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+ """Cascade roi head including one bbox head and one mask head.
+
+ https://arxiv.org/abs/1712.00726
+ """
+
+ def __init__(self,
+ num_stages,
+ stage_loss_weights,
+ bbox_roi_extractor=None,
+ bbox_head=None,
+ mask_roi_extractor=None,
+ mask_head=None,
+ shared_head=None,
+ train_cfg=None,
+ test_cfg=None):
+ assert bbox_roi_extractor is not None
+ assert bbox_head is not None
+ assert shared_head is None, \
+ 'Shared head is not supported in Cascade RCNN anymore'
+ self.num_stages = num_stages
+ self.stage_loss_weights = stage_loss_weights
+ super(CascadeRoIHead, self).__init__(
+ bbox_roi_extractor=bbox_roi_extractor,
+ bbox_head=bbox_head,
+ mask_roi_extractor=mask_roi_extractor,
+ mask_head=mask_head,
+ shared_head=shared_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg)
+
+ def init_bbox_head(self, bbox_roi_extractor, bbox_head):
+ """Initialize box head and box roi extractor.
+
+ Args:
+ bbox_roi_extractor (dict): Config of box roi extractor.
+ bbox_head (dict): Config of box in box head.
+ """
+ self.bbox_roi_extractor = nn.ModuleList()
+ self.bbox_head = nn.ModuleList()
+ if not isinstance(bbox_roi_extractor, list):
+ bbox_roi_extractor = [
+ bbox_roi_extractor for _ in range(self.num_stages)
+ ]
+ if not isinstance(bbox_head, list):
+ bbox_head = [bbox_head for _ in range(self.num_stages)]
+ assert len(bbox_roi_extractor) == len(bbox_head) == self.num_stages
+ for roi_extractor, head in zip(bbox_roi_extractor, bbox_head):
+ self.bbox_roi_extractor.append(build_roi_extractor(roi_extractor))
+ self.bbox_head.append(build_head(head))
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize mask head and mask roi extractor.
+
+ Args:
+ mask_roi_extractor (dict): Config of mask roi extractor.
+ mask_head (dict): Config of mask in mask head.
+ """
+ self.mask_head = nn.ModuleList()
+ if not isinstance(mask_head, list):
+ mask_head = [mask_head for _ in range(self.num_stages)]
+ assert len(mask_head) == self.num_stages
+ for head in mask_head:
+ self.mask_head.append(build_head(head))
+ if mask_roi_extractor is not None:
+ self.share_roi_extractor = False
+ self.mask_roi_extractor = nn.ModuleList()
+ if not isinstance(mask_roi_extractor, list):
+ mask_roi_extractor = [
+ mask_roi_extractor for _ in range(self.num_stages)
+ ]
+ assert len(mask_roi_extractor) == self.num_stages
+ for roi_extractor in mask_roi_extractor:
+ self.mask_roi_extractor.append(
+ build_roi_extractor(roi_extractor))
+ else:
+ self.share_roi_extractor = True
+ self.mask_roi_extractor = self.bbox_roi_extractor
+
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler for each stage."""
+ self.bbox_assigner = []
+ self.bbox_sampler = []
+ if self.train_cfg is not None:
+ for idx, rcnn_train_cfg in enumerate(self.train_cfg):
+ self.bbox_assigner.append(
+ build_assigner(rcnn_train_cfg.assigner))
+ self.current_stage = idx
+ self.bbox_sampler.append(
+ build_sampler(rcnn_train_cfg.sampler, context=self))
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if self.with_shared_head:
+ self.shared_head.init_weights(pretrained=pretrained)
+ for i in range(self.num_stages):
+ if self.with_bbox:
+ self.bbox_roi_extractor[i].init_weights()
+ self.bbox_head[i].init_weights()
+ if self.with_mask:
+ if not self.share_roi_extractor:
+ self.mask_roi_extractor[i].init_weights()
+ self.mask_head[i].init_weights()
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ # bbox head
+ outs = ()
+ rois = bbox2roi([proposals])
+ if self.with_bbox:
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+ # mask heads
+ if self.with_mask:
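+            # a slice of rois is enough for the dummy (FLOPs / shape) pass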
+ mask_rois = rois[:100]
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ outs = outs + (mask_results['mask_pred'], )
+ return outs
+
+ def _bbox_forward(self, stage, x, rois):
+ """Box head forward function used in both training and testing."""
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
+ rois)
+ # do not support caffe_c4 model anymore
+ cls_score, bbox_pred = bbox_head(bbox_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+ return bbox_results
+
+ def _bbox_forward_train(self, stage, x, sampling_results, gt_bboxes,
+ gt_labels, rcnn_train_cfg):
+ """Run forward function and calculate loss for box head in training."""
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(stage, x, rois)
+ bbox_targets = self.bbox_head[stage].get_targets(
+ sampling_results, gt_bboxes, gt_labels, rcnn_train_cfg)
+ loss_bbox = self.bbox_head[stage].loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(
+ loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
+ return bbox_results
+
+ def _mask_forward(self, stage, x, rois):
+ """Mask head forward function used in both training and testing."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+ rois)
+ # do not support caffe_c4 model anymore
+ mask_pred = mask_head(mask_feats)
+
+ mask_results = dict(mask_pred=mask_pred)
+ return mask_results
+
+ def _mask_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_masks,
+ rcnn_train_cfg,
+ bbox_feats=None):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_results = self._mask_forward(stage, x, pos_rois)
+
+ mask_targets = self.mask_head[stage].get_targets(
+ sampling_results, gt_masks, rcnn_train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = self.mask_head[stage].loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+
+ mask_results.update(loss_mask=loss_mask)
+ return mask_results
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ proposals (list[Tensors]): list of region proposals.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ losses = dict()
+ for i in range(self.num_stages):
+ self.current_stage = i
+ rcnn_train_cfg = self.train_cfg[i]
+ lw = self.stage_loss_weights[i]
+
+ # assign gts and sample proposals
+ sampling_results = []
+ if self.with_bbox or self.with_mask:
+ bbox_assigner = self.bbox_assigner[i]
+ bbox_sampler = self.bbox_sampler[i]
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(
+ proposal_list[j], gt_bboxes[j], gt_bboxes_ignore[j],
+ gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ # bbox head forward and loss
+ bbox_results = self._bbox_forward_train(i, x, sampling_results,
+ gt_bboxes, gt_labels,
+ rcnn_train_cfg)
+
+ for name, value in bbox_results['loss_bbox'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(
+ i, x, sampling_results, gt_masks, rcnn_train_cfg,
+ bbox_results['bbox_feats'])
+ for name, value in mask_results['loss_mask'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # refine bboxes
+ if i < self.num_stages - 1:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ # bbox_targets is a tuple
+ roi_labels = bbox_results['bbox_targets'][0]
+ with torch.no_grad():
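+                    # proposals labelled as background cannot be refined by a
+                    # class label, so use the most confident foreground class
+                    # predicted by this stage instead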
+ roi_labels = torch.where(
+ roi_labels == self.bbox_head[i].num_classes,
+ bbox_results['cls_score'][:, :-1].argmax(1),
+ roi_labels)
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+ return losses
+
+ def simple_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ num_imgs = len(proposal_list)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # "ms" in variable names means multi-stage
+ ms_bbox_result = {}
+ ms_segm_result = {}
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ rois = bbox2roi(proposal_list)
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(
+ len(proposals) for proposals in proposal_list)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
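+            # bucketing heads (e.g. SABL) return bbox_pred as a tuple and
+            # provide their own split helper; plain heads return a tensor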
+ if isinstance(bbox_pred, torch.Tensor):
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ else:
+ bbox_pred = self.bbox_head[i].bbox_pred_split(
+ bbox_pred, num_proposals_per_img)
+ ms_scores.append(cls_score)
+
+ if i < self.num_stages - 1:
+ bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score]
+ rois = torch.cat([
+ self.bbox_head[i].regress_by_class(rois[j], bbox_label[j],
+ bbox_pred[j],
+ img_metas[j])
+ for j in range(num_imgs)
+ ])
+
+ # average scores of each image by stages
+ cls_score = [
+ sum([score[i] for score in ms_scores]) / float(len(ms_scores))
+ for i in range(num_imgs)
+ ]
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(num_imgs):
+ det_bbox, det_label = self.bbox_head[-1].get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+
+ if torch.onnx.is_in_onnx_export():
+ return det_bboxes, det_labels
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+ for i in range(num_imgs)
+ ]
+ ms_bbox_result['ensemble'] = bbox_results
+
+ if self.with_mask:
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ mask_classes = self.mask_head[-1].num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ else:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ num_mask_rois_per_img = tuple(
+ _bbox.size(0) for _bbox in _bboxes)
+ aug_masks = []
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+ # split batch mask prediction back to each image
+ mask_pred = mask_pred.split(num_mask_rois_per_img, 0)
+ aug_masks.append(
+ [m.sigmoid().cpu().numpy() for m in mask_pred])
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[]
+ for _ in range(self.mask_head[-1].num_classes)])
+ else:
+ aug_mask = [mask[i] for mask in aug_masks]
+ merged_masks = merge_aug_masks(
+ aug_mask, [[img_metas[i]]] * self.num_stages,
+ rcnn_test_cfg)
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_masks, _bboxes[i], det_labels[i],
+ rcnn_test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ segm_results.append(segm_result)
+ ms_segm_result['ensemble'] = segm_results
+
+ if self.with_mask:
+ results = list(
+ zip(ms_bbox_result['ensemble'], ms_segm_result['ensemble']))
+ else:
+ results = ms_bbox_result['ensemble']
+
+ return results
+
+ def aug_test(self, features, proposal_list, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ rcnn_test_cfg = self.test_cfg
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta in zip(features, img_metas):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+
+ rois = bbox2roi([proposals])
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(i, x, rois)
+ ms_scores.append(bbox_results['cls_score'])
+
+ if i < self.num_stages - 1:
+ bbox_label = bbox_results['cls_score'][:, :-1].argmax(
+ dim=1)
+ rois = self.bbox_head[i].regress_by_class(
+ rois, bbox_label, bbox_results['bbox_pred'],
+ img_meta[0])
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bboxes, scores = self.bbox_head[-1].get_bboxes(
+ rois,
+ cls_score,
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+
+ bbox_result = bbox2result(det_bboxes, det_labels,
+ self.bbox_head[-1].num_classes)
+
+ if self.with_mask:
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[[]
+ for _ in range(self.mask_head[-1].num_classes)]
+ ]
+ else:
+ aug_masks = []
+ aug_img_metas = []
+ for x, img_meta in zip(features, img_metas):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ mask_rois = bbox2roi([_bboxes])
+ for i in range(self.num_stages):
+ mask_results = self._mask_forward(i, x, mask_rois)
+ aug_masks.append(
+ mask_results['mask_pred'].sigmoid().cpu().numpy())
+ aug_img_metas.append(img_meta)
+ merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
+ self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ rcnn_test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return [(bbox_result, segm_result)]
+ else:
+ return [bbox_result]
diff --git a/mmdet/models/roi_heads/double_roi_head.py b/mmdet/models/roi_heads/double_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1aa6c8244a889fbbed312a89574c3e11be294f0
--- /dev/null
+++ b/mmdet/models/roi_heads/double_roi_head.py
@@ -0,0 +1,33 @@
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class DoubleHeadRoIHead(StandardRoIHead):
+ """RoI head for Double Head RCNN.
+
+ https://arxiv.org/abs/1904.06493
+ """
+
+ def __init__(self, reg_roi_scale_factor, **kwargs):
+ super(DoubleHeadRoIHead, self).__init__(**kwargs)
+ self.reg_roi_scale_factor = reg_roi_scale_factor
+
+ def _bbox_forward(self, x, rois):
+ """Box head forward function used in both training and testing time."""
+ bbox_cls_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs], rois)
+ bbox_reg_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs],
+ rois,
+ roi_scale_factor=self.reg_roi_scale_factor)
+ if self.with_shared_head:
+ bbox_cls_feats = self.shared_head(bbox_cls_feats)
+ bbox_reg_feats = self.shared_head(bbox_reg_feats)
+ cls_score, bbox_pred = self.bbox_head(bbox_cls_feats, bbox_reg_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score,
+ bbox_pred=bbox_pred,
+ bbox_feats=bbox_cls_feats)
+ return bbox_results
diff --git a/mmdet/models/roi_heads/dynamic_roi_head.py b/mmdet/models/roi_heads/dynamic_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..89427a931f45f5a920c0e66fd88058bf9fa05f5c
--- /dev/null
+++ b/mmdet/models/roi_heads/dynamic_roi_head.py
@@ -0,0 +1,154 @@
+import numpy as np
+import torch
+
+from mmdet.core import bbox2roi
+from mmdet.models.losses import SmoothL1Loss
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+EPS = 1e-15
+
+
+@HEADS.register_module()
+class DynamicRoIHead(StandardRoIHead):
+ """RoI head for `Dynamic R-CNN `_."""
+
+ def __init__(self, **kwargs):
+ super(DynamicRoIHead, self).__init__(**kwargs)
+ assert isinstance(self.bbox_head.loss_bbox, SmoothL1Loss)
+ # the IoU history of the past `update_iter_interval` iterations
+ self.iou_history = []
+ # the beta history of the past `update_iter_interval` iterations
+ self.beta_history = []
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """Forward function for training.
+
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ proposals (list[Tensors]): list of region proposals.
+
+ gt_bboxes (list[Tensor]): each item are the truth boxes for each
+ image in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ gt_masks (None | Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # assign gts and sample proposals
+ if self.with_bbox or self.with_mask:
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ cur_iou = []
+ for i in range(num_imgs):
+ assign_result = self.bbox_assigner.assign(
+ proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+ gt_labels[i])
+ sampling_result = self.bbox_sampler.sample(
+ assign_result,
+ proposal_list[i],
+ gt_bboxes[i],
+ gt_labels[i],
+ feats=[lvl_feat[i][None] for lvl_feat in x])
+ # record the `iou_topk`-th largest IoU in an image
+ iou_topk = min(self.train_cfg.dynamic_rcnn.iou_topk,
+ len(assign_result.max_overlaps))
+ ious, _ = torch.topk(assign_result.max_overlaps, iou_topk)
+ cur_iou.append(ious[-1].item())
+ sampling_results.append(sampling_result)
+ # average the current IoUs over images
+ cur_iou = np.mean(cur_iou)
+ self.iou_history.append(cur_iou)
+
+ losses = dict()
+ # bbox head forward and loss
+ if self.with_bbox:
+ bbox_results = self._bbox_forward_train(x, sampling_results,
+ gt_bboxes, gt_labels,
+ img_metas)
+ losses.update(bbox_results['loss_bbox'])
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(x, sampling_results,
+ bbox_results['bbox_feats'],
+ gt_masks, img_metas)
+ losses.update(mask_results['loss_mask'])
+
+ # update IoU threshold and SmoothL1 beta
+ update_iter_interval = self.train_cfg.dynamic_rcnn.update_iter_interval
+ if len(self.iou_history) % update_iter_interval == 0:
+ new_iou_thr, new_beta = self.update_hyperparameters()
+
+ return losses
+
+ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+ img_metas):
+ num_imgs = len(img_metas)
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(x, rois)
+
+ bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, self.train_cfg)
+ # record the `beta_topk`-th smallest target
+ # `bbox_targets[2]` and `bbox_targets[3]` stand for bbox_targets
+ # and bbox_weights, respectively
+ pos_inds = bbox_targets[3][:, 0].nonzero().squeeze(1)
+ num_pos = len(pos_inds)
+ cur_target = bbox_targets[2][pos_inds, :2].abs().mean(dim=1)
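+        # use the ``beta_topk``-th smallest of these mean |dx, dy| targets as a
+        # robust statistic for later shrinking of the SmoothL1 beta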
+ beta_topk = min(self.train_cfg.dynamic_rcnn.beta_topk * num_imgs,
+ num_pos)
+ cur_target = torch.kthvalue(cur_target, beta_topk)[0].item()
+ self.beta_history.append(cur_target)
+ loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(loss_bbox=loss_bbox)
+ return bbox_results
+
+ def update_hyperparameters(self):
+ """Update hyperparameters like IoU thresholds for assigner and beta for
+ SmoothL1 loss based on the training statistics.
+
+ Returns:
+ tuple[float]: the updated ``iou_thr`` and ``beta``.
+ """
+ new_iou_thr = max(self.train_cfg.dynamic_rcnn.initial_iou,
+ np.mean(self.iou_history))
+ self.iou_history = []
+ self.bbox_assigner.pos_iou_thr = new_iou_thr
+ self.bbox_assigner.neg_iou_thr = new_iou_thr
+ self.bbox_assigner.min_pos_iou = new_iou_thr
+ if (np.median(self.beta_history) < EPS):
+ # avoid 0 or too small value for new_beta
+ new_beta = self.bbox_head.loss_bbox.beta
+ else:
+ new_beta = min(self.train_cfg.dynamic_rcnn.initial_beta,
+ np.median(self.beta_history))
+ self.beta_history = []
+ self.bbox_head.loss_bbox.beta = new_beta
+ return new_iou_thr, new_beta
diff --git a/mmdet/models/roi_heads/grid_roi_head.py b/mmdet/models/roi_heads/grid_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4c52c79863ebaf17bd023382c7e5d4c237b4da77
--- /dev/null
+++ b/mmdet/models/roi_heads/grid_roi_head.py
@@ -0,0 +1,176 @@
+import torch
+
+from mmdet.core import bbox2result, bbox2roi
+from ..builder import HEADS, build_head, build_roi_extractor
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class GridRoIHead(StandardRoIHead):
+ """Grid roi head for Grid R-CNN.
+
+ https://arxiv.org/abs/1811.12030
+ """
+
+ def __init__(self, grid_roi_extractor, grid_head, **kwargs):
+ assert grid_head is not None
+ super(GridRoIHead, self).__init__(**kwargs)
+ if grid_roi_extractor is not None:
+ self.grid_roi_extractor = build_roi_extractor(grid_roi_extractor)
+ self.share_roi_extractor = False
+ else:
+ self.share_roi_extractor = True
+ self.grid_roi_extractor = self.bbox_roi_extractor
+ self.grid_head = build_head(grid_head)
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super(GridRoIHead, self).init_weights(pretrained)
+ self.grid_head.init_weights()
+ if not self.share_roi_extractor:
+ self.grid_roi_extractor.init_weights()
+
+ def _random_jitter(self, sampling_results, img_metas, amplitude=0.15):
+ """Ramdom jitter positive proposals for training."""
+ for sampling_result, img_meta in zip(sampling_results, img_metas):
+ bboxes = sampling_result.pos_bboxes
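+            # sample per-box offsets uniformly in [-amplitude, amplitude]; they
+            # are applied relative to each box's own width/height below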
+ random_offsets = bboxes.new_empty(bboxes.shape[0], 4).uniform_(
+ -amplitude, amplitude)
+ # before jittering
+ cxcy = (bboxes[:, 2:4] + bboxes[:, :2]) / 2
+ wh = (bboxes[:, 2:4] - bboxes[:, :2]).abs()
+ # after jittering
+ new_cxcy = cxcy + wh * random_offsets[:, :2]
+ new_wh = wh * (1 + random_offsets[:, 2:])
+ # xywh to xyxy
+ new_x1y1 = (new_cxcy - new_wh / 2)
+ new_x2y2 = (new_cxcy + new_wh / 2)
+ new_bboxes = torch.cat([new_x1y1, new_x2y2], dim=1)
+ # clip bboxes
+ max_shape = img_meta['img_shape']
+ if max_shape is not None:
+ new_bboxes[:, 0::2].clamp_(min=0, max=max_shape[1] - 1)
+ new_bboxes[:, 1::2].clamp_(min=0, max=max_shape[0] - 1)
+
+ sampling_result.pos_bboxes = new_bboxes
+ return sampling_results
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ # bbox head
+ outs = ()
+ rois = bbox2roi([proposals])
+ if self.with_bbox:
+ bbox_results = self._bbox_forward(x, rois)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+
+ # grid head
+ grid_rois = rois[:100]
+ grid_feats = self.grid_roi_extractor(
+ x[:self.grid_roi_extractor.num_inputs], grid_rois)
+ if self.with_shared_head:
+ grid_feats = self.shared_head(grid_feats)
+ grid_pred = self.grid_head(grid_feats)
+ outs = outs + (grid_pred, )
+
+ # mask head
+ if self.with_mask:
+ mask_rois = rois[:100]
+ mask_results = self._mask_forward(x, mask_rois)
+ outs = outs + (mask_results['mask_pred'], )
+ return outs
+
+ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+ img_metas):
+ """Run forward function and calculate loss for box head in training."""
+ bbox_results = super(GridRoIHead,
+ self)._bbox_forward_train(x, sampling_results,
+ gt_bboxes, gt_labels,
+ img_metas)
+
+ # Grid head forward and loss
+ sampling_results = self._random_jitter(sampling_results, img_metas)
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+
+ # GN in head does not support zero shape input
+ if pos_rois.shape[0] == 0:
+ return bbox_results
+
+ grid_feats = self.grid_roi_extractor(
+ x[:self.grid_roi_extractor.num_inputs], pos_rois)
+ if self.with_shared_head:
+ grid_feats = self.shared_head(grid_feats)
+ # Accelerate training
+ max_sample_num_grid = self.train_cfg.get('max_num_grid', 192)
+ sample_idx = torch.randperm(
+ grid_feats.shape[0])[:min(grid_feats.shape[0], max_sample_num_grid
+ )]
+ grid_feats = grid_feats[sample_idx]
+
+ grid_pred = self.grid_head(grid_feats)
+
+ grid_targets = self.grid_head.get_targets(sampling_results,
+ self.train_cfg)
+ grid_targets = grid_targets[sample_idx]
+
+ loss_grid = self.grid_head.loss(grid_pred, grid_targets)
+
+ bbox_results['loss_bbox'].update(loss_grid)
+ return bbox_results
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ det_bboxes, det_labels = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=False)
+        # pack detected bboxes into rois
+ grid_rois = bbox2roi([det_bbox[:, :4] for det_bbox in det_bboxes])
+ if grid_rois.shape[0] != 0:
+ grid_feats = self.grid_roi_extractor(
+ x[:len(self.grid_roi_extractor.featmap_strides)], grid_rois)
+ self.grid_head.test_mode = True
+ grid_pred = self.grid_head(grid_feats)
+ # split batch grid head prediction back to each image
+ num_roi_per_img = tuple(len(det_bbox) for det_bbox in det_bboxes)
+ grid_pred = {
+ k: v.split(num_roi_per_img, 0)
+ for k, v in grid_pred.items()
+ }
+
+ # apply bbox post-processing to each image individually
+ bbox_results = []
+ num_imgs = len(det_bboxes)
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ bbox_results.append(grid_rois.new_tensor([]))
+ else:
+ det_bbox = self.grid_head.get_bboxes(
+ det_bboxes[i], grid_pred['fused'][i], [img_metas[i]])
+ if rescale:
+ det_bbox[:, :4] /= img_metas[i]['scale_factor']
+ bbox_results.append(
+ bbox2result(det_bbox, det_labels[i],
+ self.bbox_head.num_classes))
+ else:
+ bbox_results = [
+ grid_rois.new_tensor([]) for _ in range(len(det_bboxes))
+ ]
+
+ if not self.with_mask:
+ return bbox_results
+ else:
+ segm_results = self.simple_test_mask(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ return list(zip(bbox_results, segm_results))
diff --git a/mmdet/models/roi_heads/htc_roi_head.py b/mmdet/models/roi_heads/htc_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b5c2ec3bc9d579061fbd89f8b320e6e59909143
--- /dev/null
+++ b/mmdet/models/roi_heads/htc_roi_head.py
@@ -0,0 +1,589 @@
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+ merge_aug_masks, multiclass_nms)
+from ..builder import HEADS, build_head, build_roi_extractor
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module()
+class HybridTaskCascadeRoIHead(CascadeRoIHead):
+ """Hybrid task cascade roi head including one bbox head and one mask head.
+
+ https://arxiv.org/abs/1901.07518
+ """
+
+ def __init__(self,
+ num_stages,
+ stage_loss_weights,
+ semantic_roi_extractor=None,
+ semantic_head=None,
+ semantic_fusion=('bbox', 'mask'),
+ interleaved=True,
+ mask_info_flow=True,
+ **kwargs):
+ super(HybridTaskCascadeRoIHead,
+ self).__init__(num_stages, stage_loss_weights, **kwargs)
+ assert self.with_bbox and self.with_mask
+ assert not self.with_shared_head # shared head is not supported
+
+ if semantic_head is not None:
+ self.semantic_roi_extractor = build_roi_extractor(
+ semantic_roi_extractor)
+ self.semantic_head = build_head(semantic_head)
+
+ self.semantic_fusion = semantic_fusion
+ self.interleaved = interleaved
+ self.mask_info_flow = mask_info_flow
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super(HybridTaskCascadeRoIHead, self).init_weights(pretrained)
+ if self.with_semantic:
+ self.semantic_head.init_weights()
+
+ @property
+ def with_semantic(self):
+ """bool: whether the head has semantic head"""
+ if hasattr(self, 'semantic_head') and self.semantic_head is not None:
+ return True
+ else:
+ return False
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ outs = ()
+ # semantic head
+ if self.with_semantic:
+ _, semantic_feat = self.semantic_head(x)
+ else:
+ semantic_feat = None
+ # bbox heads
+ rois = bbox2roi([proposals])
+ for i in range(self.num_stages):
+ bbox_results = self._bbox_forward(
+ i, x, rois, semantic_feat=semantic_feat)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+ # mask heads
+ if self.with_mask:
+ mask_rois = rois[:100]
+ mask_roi_extractor = self.mask_roi_extractor[-1]
+ mask_feats = mask_roi_extractor(
+ x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor(
+ [semantic_feat], mask_rois)
+ mask_feats += mask_semantic_feat
+ last_feat = None
+ for i in range(self.num_stages):
+ mask_head = self.mask_head[i]
+ if self.mask_info_flow:
+ mask_pred, last_feat = mask_head(mask_feats, last_feat)
+ else:
+ mask_pred = mask_head(mask_feats)
+ outs = outs + (mask_pred, )
+ return outs
+
+ def _bbox_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ semantic_feat=None):
+ """Run forward function and calculate loss for box head in training."""
+ bbox_head = self.bbox_head[stage]
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(
+ stage, x, rois, semantic_feat=semantic_feat)
+
+ bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, rcnn_train_cfg)
+ loss_bbox = bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(
+ loss_bbox=loss_bbox,
+ rois=rois,
+ bbox_targets=bbox_targets,
+ )
+ return bbox_results
+
+ def _mask_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_masks,
+ rcnn_train_cfg,
+ semantic_feat=None):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_feats = mask_roi_extractor(x[:mask_roi_extractor.num_inputs],
+ pos_rois)
+
+ # semantic feature fusion
+ # element-wise sum for original features and pooled semantic features
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ pos_rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats += mask_semantic_feat
+
+ # mask information flow
+ # forward all previous mask heads to obtain last_feat, and fuse it
+ # with the normal mask feature
+ if self.mask_info_flow:
+ last_feat = None
+ for i in range(stage):
+ last_feat = self.mask_head[i](
+ mask_feats, last_feat, return_logits=False)
+ mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
+ else:
+ mask_pred = mask_head(mask_feats, return_feat=False)
+
+ mask_targets = mask_head.get_targets(sampling_results, gt_masks,
+ rcnn_train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = mask_head.loss(mask_pred, mask_targets, pos_labels)
+
+ mask_results = dict(loss_mask=loss_mask)
+ return mask_results
+
+ def _bbox_forward(self, stage, x, rois, semantic_feat=None):
+ """Box head forward function used in both training and testing."""
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(
+ x[:len(bbox_roi_extractor.featmap_strides)], rois)
+ if self.with_semantic and 'bbox' in self.semantic_fusion:
+ bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ rois)
+ if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
+ bbox_semantic_feat = F.adaptive_avg_pool2d(
+ bbox_semantic_feat, bbox_feats.shape[-2:])
+ bbox_feats += bbox_semantic_feat
+ cls_score, bbox_pred = bbox_head(bbox_feats)
+
+ bbox_results = dict(cls_score=cls_score, bbox_pred=bbox_pred)
+ return bbox_results
+
+ def _mask_forward_test(self, stage, x, bboxes, semantic_feat=None):
+ """Mask head forward function for testing."""
+ mask_roi_extractor = self.mask_roi_extractor[stage]
+ mask_head = self.mask_head[stage]
+ mask_rois = bbox2roi([bboxes])
+ mask_feats = mask_roi_extractor(
+ x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ mask_rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats += mask_semantic_feat
+ if self.mask_info_flow:
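+            # mask information flow: run the previous stages' mask heads first,
+            # accumulating the hidden feature (last_feat) and the summed logits
+            # (last_pred) that are added to the current stage's prediction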
+ last_feat = None
+ last_pred = None
+ for i in range(stage):
+ mask_pred, last_feat = self.mask_head[i](mask_feats, last_feat)
+ if last_pred is not None:
+ mask_pred = mask_pred + last_pred
+ last_pred = mask_pred
+ mask_pred = mask_head(mask_feats, last_feat, return_feat=False)
+ if last_pred is not None:
+ mask_pred = mask_pred + last_pred
+ else:
+ mask_pred = mask_head(mask_feats)
+ return mask_pred
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ gt_semantic_seg=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+ proposal_list (list[Tensors]): list of region proposals.
+
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+ gt_bboxes_ignore (None, list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+
+ gt_masks (None, Tensor) : true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ gt_semantic_seg (None, list[Tensor]): semantic segmentation masks
+ used if the architecture supports semantic segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # semantic segmentation part
+ # 2 outputs: segmentation prediction and embedded features
+ losses = dict()
+ if self.with_semantic:
+ semantic_pred, semantic_feat = self.semantic_head(x)
+ loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
+ losses['loss_semantic_seg'] = loss_seg
+ else:
+ semantic_feat = None
+
+ for i in range(self.num_stages):
+ self.current_stage = i
+ rcnn_train_cfg = self.train_cfg[i]
+ lw = self.stage_loss_weights[i]
+
+ # assign gts and sample proposals
+ sampling_results = []
+ bbox_assigner = self.bbox_assigner[i]
+ bbox_sampler = self.bbox_sampler[i]
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(proposal_list[j],
+ gt_bboxes[j],
+ gt_bboxes_ignore[j],
+ gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ # bbox head forward and loss
+ bbox_results = \
+ self._bbox_forward_train(
+ i, x, sampling_results, gt_bboxes, gt_labels,
+ rcnn_train_cfg, semantic_feat)
+ roi_labels = bbox_results['bbox_targets'][0]
+
+ for name, value in bbox_results['loss_bbox'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # mask head forward and loss
+ if self.with_mask:
+ # interleaved execution: use regressed bboxes by the box branch
+ # to train the mask branch
+ if self.interleaved:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ with torch.no_grad():
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+                    # re-assign and re-sample RoIs using the refined proposals
+ sampling_results = []
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(
+ proposal_list[j], gt_bboxes[j],
+ gt_bboxes_ignore[j], gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+ mask_results = self._mask_forward_train(
+ i, x, sampling_results, gt_masks, rcnn_train_cfg,
+ semantic_feat)
+ for name, value in mask_results['loss_mask'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # refine bboxes (same as Cascade R-CNN)
+ if i < self.num_stages - 1 and not self.interleaved:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ with torch.no_grad():
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+ return losses
+
+ def simple_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test without augmentation."""
+ if self.with_semantic:
+ _, semantic_feat = self.semantic_head(x)
+ else:
+ semantic_feat = None
+
+ num_imgs = len(proposal_list)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # "ms" in variable names means multi-stage
+ ms_bbox_result = {}
+ ms_segm_result = {}
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ rois = bbox2roi(proposal_list)
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i, x, rois, semantic_feat=semantic_feat)
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(len(p) for p in proposal_list)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ ms_scores.append(cls_score)
+
+ if i < self.num_stages - 1:
+ bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score]
+ rois = torch.cat([
+ bbox_head.regress_by_class(rois[i], bbox_label[i],
+ bbox_pred[i], img_metas[i])
+ for i in range(num_imgs)
+ ])
+
+ # average scores of each image by stages
+ cls_score = [
+ sum([score[i] for score in ms_scores]) / float(len(ms_scores))
+ for i in range(num_imgs)
+ ]
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(num_imgs):
+ det_bbox, det_label = self.bbox_head[-1].get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+ bbox_result = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+ for i in range(num_imgs)
+ ]
+ ms_bbox_result['ensemble'] = bbox_result
+
+ if self.with_mask:
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ mask_classes = self.mask_head[-1].num_classes
+ segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ else:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i]
+ for i in range(num_imgs)
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ aug_masks = []
+ mask_roi_extractor = self.mask_roi_extractor[-1]
+ mask_feats = mask_roi_extractor(
+ x[:len(mask_roi_extractor.featmap_strides)], mask_rois)
+ if self.with_semantic and 'mask' in self.semantic_fusion:
+ mask_semantic_feat = self.semantic_roi_extractor(
+ [semantic_feat], mask_rois)
+ mask_feats += mask_semantic_feat
+ last_feat = None
+
+ num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
+ for i in range(self.num_stages):
+ mask_head = self.mask_head[i]
+ if self.mask_info_flow:
+ mask_pred, last_feat = mask_head(mask_feats, last_feat)
+ else:
+ mask_pred = mask_head(mask_feats)
+
+ # split batch mask prediction back to each image
+ mask_pred = mask_pred.split(num_bbox_per_img, 0)
+ aug_masks.append(
+ [mask.sigmoid().cpu().numpy() for mask in mask_pred])
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[]
+ for _ in range(self.mask_head[-1].num_classes)])
+ else:
+ aug_mask = [mask[i] for mask in aug_masks]
+ merged_mask = merge_aug_masks(
+ aug_mask, [[img_metas[i]]] * self.num_stages,
+ rcnn_test_cfg)
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_mask, _bboxes[i], det_labels[i],
+ rcnn_test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ segm_results.append(segm_result)
+ ms_segm_result['ensemble'] = segm_results
+
+ if self.with_mask:
+ results = list(
+ zip(ms_bbox_result['ensemble'], ms_segm_result['ensemble']))
+ else:
+ results = ms_bbox_result['ensemble']
+
+ return results
+
+ def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ if self.with_semantic:
+ semantic_feats = [
+ self.semantic_head(feat)[1] for feat in img_feats
+ ]
+ else:
+ semantic_feats = [None] * len(img_metas)
+
+ rcnn_test_cfg = self.test_cfg
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta, semantic in zip(img_feats, img_metas, semantic_feats):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+
+ rois = bbox2roi([proposals])
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i, x, rois, semantic_feat=semantic)
+ ms_scores.append(bbox_results['cls_score'])
+
+ if i < self.num_stages - 1:
+ bbox_label = bbox_results['cls_score'].argmax(dim=1)
+ rois = bbox_head.regress_by_class(
+ rois, bbox_label, bbox_results['bbox_pred'],
+ img_meta[0])
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bboxes, scores = self.bbox_head[-1].get_bboxes(
+ rois,
+ cls_score,
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+
+ bbox_result = bbox2result(det_bboxes, det_labels,
+ self.bbox_head[-1].num_classes)
+
+ if self.with_mask:
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[[]
+ for _ in range(self.mask_head[-1].num_classes)]
+ ]
+ else:
+ aug_masks = []
+ aug_img_metas = []
+ for x, img_meta, semantic in zip(img_feats, img_metas,
+ semantic_feats):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ mask_rois = bbox2roi([_bboxes])
+ mask_feats = self.mask_roi_extractor[-1](
+ x[:len(self.mask_roi_extractor[-1].featmap_strides)],
+ mask_rois)
+ if self.with_semantic:
+ semantic_feat = semantic
+ mask_semantic_feat = self.semantic_roi_extractor(
+ [semantic_feat], mask_rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[
+ -2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats += mask_semantic_feat
+ last_feat = None
+ for i in range(self.num_stages):
+ mask_head = self.mask_head[i]
+ if self.mask_info_flow:
+ mask_pred, last_feat = mask_head(
+ mask_feats, last_feat)
+ else:
+ mask_pred = mask_head(mask_feats)
+ aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+ aug_img_metas.append(img_meta)
+ merged_masks = merge_aug_masks(aug_masks, aug_img_metas,
+ self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ segm_result = self.mask_head[-1].get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ rcnn_test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return [(bbox_result, segm_result)]
+ else:
+ return [bbox_result]
diff --git a/mmdet/models/roi_heads/mask_heads/__init__.py b/mmdet/models/roi_heads/mask_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c8a76a76f275287006a5e9bbb52b9de962b627c
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/__init__.py
@@ -0,0 +1,18 @@
+from .coarse_mask_head import CoarseMaskHead
+from .fcn_mask_head import FCNMaskHead
+from .fcn_occmask_head import FCNOccMaskHead
+from .feature_relay_head import FeatureRelayHead
+from .fused_semantic_head import FusedSemanticHead
+from .global_context_head import GlobalContextHead
+from .grid_head import GridHead
+from .htc_mask_head import HTCMaskHead
+from .mask_point_head import MaskPointHead
+from .maskiou_head import MaskIoUHead
+from .scnet_mask_head import SCNetMaskHead
+from .scnet_semantic_head import SCNetSemanticHead
+
+__all__ = [
+ 'FCNMaskHead', 'FCNOccMaskHead', 'HTCMaskHead', 'FusedSemanticHead', 'GridHead',
+ 'MaskIoUHead', 'CoarseMaskHead', 'MaskPointHead', 'SCNetMaskHead',
+ 'SCNetSemanticHead', 'GlobalContextHead', 'FeatureRelayHead'
+]
diff --git a/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py b/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d665dfff83855e6db3866c681559ccdef09f9999
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/coarse_mask_head.py
@@ -0,0 +1,91 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule, Linear, constant_init, xavier_init
+from mmcv.runner import auto_fp16
+
+from mmdet.models.builder import HEADS
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class CoarseMaskHead(FCNMaskHead):
+ """Coarse mask head used in PointRend.
+
+ Compared with standard ``FCNMaskHead``, ``CoarseMaskHead`` will downsample
+ the input feature map instead of upsample it.
+
+ Args:
+ num_convs (int): Number of conv layers in the head. Default: 0.
+ num_fcs (int): Number of fc layers in the head. Default: 2.
+ fc_out_channels (int): Number of output channels of fc layer.
+ Default: 1024.
+ downsample_factor (int): The factor that feature map is downsampled by.
+ Default: 2.
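+
+    Example:
+        >>> # Illustrative shape sketch with the default settings
+        >>> import torch
+        >>> self = CoarseMaskHead()
+        >>> inputs = torch.rand(2, self.in_channels, 14, 14)
+        >>> mask_pred = self.forward(inputs)
+        >>> # the 14x14 RoI features are downsampled by a factor of 2
+        >>> assert mask_pred.shape == (2, self.num_classes, 7, 7)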
+ """
+
+ def __init__(self,
+ num_convs=0,
+ num_fcs=2,
+ fc_out_channels=1024,
+ downsample_factor=2,
+ *arg,
+ **kwarg):
+ super(CoarseMaskHead, self).__init__(
+ *arg, num_convs=num_convs, upsample_cfg=dict(type=None), **kwarg)
+ self.num_fcs = num_fcs
+ assert self.num_fcs > 0
+ self.fc_out_channels = fc_out_channels
+ self.downsample_factor = downsample_factor
+ assert self.downsample_factor >= 1
+ # remove conv_logit
+ delattr(self, 'conv_logits')
+
+ if downsample_factor > 1:
+ downsample_in_channels = (
+ self.conv_out_channels
+ if self.num_convs > 0 else self.in_channels)
+ self.downsample_conv = ConvModule(
+ downsample_in_channels,
+ self.conv_out_channels,
+ kernel_size=downsample_factor,
+ stride=downsample_factor,
+ padding=0,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ else:
+ self.downsample_conv = None
+
+ self.output_size = (self.roi_feat_size[0] // downsample_factor,
+ self.roi_feat_size[1] // downsample_factor)
+ self.output_area = self.output_size[0] * self.output_size[1]
+
+ last_layer_dim = self.conv_out_channels * self.output_area
+
+ self.fcs = nn.ModuleList()
+ for i in range(num_fcs):
+ fc_in_channels = (
+ last_layer_dim if i == 0 else self.fc_out_channels)
+ self.fcs.append(Linear(fc_in_channels, self.fc_out_channels))
+ last_layer_dim = self.fc_out_channels
+ output_channels = self.num_classes * self.output_area
+ self.fc_logits = Linear(last_layer_dim, output_channels)
+
+ def init_weights(self):
+ for m in self.fcs.modules():
+ if isinstance(m, nn.Linear):
+ xavier_init(m)
+ constant_init(self.fc_logits, 0.001)
+
+ @auto_fp16()
+ def forward(self, x):
+ for conv in self.convs:
+ x = conv(x)
+
+ if self.downsample_conv is not None:
+ x = self.downsample_conv(x)
+
+ x = x.flatten(1)
+ for fc in self.fcs:
+ x = self.relu(fc(x))
+ mask_pred = self.fc_logits(x).view(
+ x.size(0), self.num_classes, *self.output_size)
+ return mask_pred
diff --git a/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..af12cb06786ce17df331ac74e41b563b294387c0
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/fcn_mask_head.py
@@ -0,0 +1,531 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, ConvModule, build_upsample_layer
+from mmcv.ops.carafe import CARAFEPack
+from mmcv.runner import auto_fp16, force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.core import mask_target
+from mmdet.models.builder import HEADS, build_loss
+
+BYTES_PER_FLOAT = 4
+# TODO: This memory limit may be too much or too little. It would be better to
+# determine it based on available resources.
+GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit
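+# Rough illustration of the chunking below: pasting 100 masks onto a
+# 3000x4000 image touches about 100 * 3000 * 4000 * BYTES_PER_FLOAT
+# ~= 4.8e9 bytes, so get_seg_masks() would split the work into
+# ceil(4.8e9 / GPU_MEM_LIMIT) = 5 chunks on GPU.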
+
+
+@HEADS.register_module()
+class FCNMaskHead(nn.Module):
+
+ def __init__(self,
+ num_convs=4,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_kernel_size=3,
+ conv_out_channels=256,
+ num_classes=80,
+ class_agnostic=False,
+ upsample_cfg=dict(type='deconv', scale_factor=2),
+ conv_cfg=None,
+ norm_cfg=None,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
+ super(FCNMaskHead, self).__init__()
+ self.upsample_cfg = upsample_cfg.copy()
+ if self.upsample_cfg['type'] not in [
+ None, 'deconv', 'nearest', 'bilinear', 'carafe'
+ ]:
+ raise ValueError(
+ f'Invalid upsample method {self.upsample_cfg["type"]}, '
+ 'accepted methods are "deconv", "nearest", "bilinear", '
+ '"carafe"')
+ self.num_convs = num_convs
+ # WARN: roi_feat_size is reserved and not used
+ self.roi_feat_size = _pair(roi_feat_size)
+ self.in_channels = in_channels
+ self.conv_kernel_size = conv_kernel_size
+ self.conv_out_channels = conv_out_channels
+ self.upsample_method = self.upsample_cfg.get('type')
+ self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
+ self.num_classes = num_classes
+ self.class_agnostic = class_agnostic
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.fp16_enabled = False
+ self.loss_mask = build_loss(loss_mask)
+
+ self.convs = nn.ModuleList()
+ for i in range(self.num_convs):
+ in_channels = (
+ self.in_channels if i == 0 else self.conv_out_channels)
+ padding = (self.conv_kernel_size - 1) // 2
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ self.conv_out_channels,
+ self.conv_kernel_size,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+ upsample_in_channels = (
+ self.conv_out_channels if self.num_convs > 0 else in_channels)
+ upsample_cfg_ = self.upsample_cfg.copy()
+ if self.upsample_method is None:
+ self.upsample = None
+ elif self.upsample_method == 'deconv':
+ upsample_cfg_.update(
+ in_channels=upsample_in_channels,
+ out_channels=self.conv_out_channels,
+ kernel_size=self.scale_factor,
+ stride=self.scale_factor)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+ elif self.upsample_method == 'carafe':
+ upsample_cfg_.update(
+ channels=upsample_in_channels, scale_factor=self.scale_factor)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample_method == 'nearest' else False)
+ upsample_cfg_.update(
+ scale_factor=self.scale_factor,
+ mode=self.upsample_method,
+ align_corners=align_corners)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+
+ out_channels = 1 if self.class_agnostic else self.num_classes
+ logits_in_channel = (
+ self.conv_out_channels
+ if self.upsample_method == 'deconv' else upsample_in_channels)
+ self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
+ self.relu = nn.ReLU(inplace=True)
+ self.debug_imgs = None
+
+ def init_weights(self):
+ for m in [self.upsample, self.conv_logits]:
+ if m is None:
+ continue
+ elif isinstance(m, CARAFEPack):
+ m.init_weights()
+ else:
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.constant_(m.bias, 0)
+
+ @auto_fp16()
+ def forward(self, x):
+ for conv in self.convs:
+ x = conv(x)
+ if self.upsample is not None:
+ x = self.upsample(x)
+ if self.upsample_method == 'deconv':
+ x = self.relu(x)
+ mask_pred = self.conv_logits(x)
+ return mask_pred
+
+ def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ pos_assigned_gt_inds = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+ mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
+ gt_masks, rcnn_train_cfg)
+ return mask_targets
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, mask_targets, labels):
+ """
+ Example:
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> # There are lots of variations depending on the configuration
+ >>> self = FCNMaskHead(num_classes=C, num_convs=1)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> sf = self.scale_factor
+ >>> labels = torch.randint(0, C, size=(N,))
+ >>> # With the default properties the mask targets should indicate
+ >>> # a (potentially soft) single-class label
+ >>> mask_targets = torch.rand(N, H * sf, W * sf)
+ >>> loss = self.loss(mask_pred, mask_targets, labels)
+ >>> print('loss = {!r}'.format(loss))
+ """
+ loss = dict()
+        if mask_pred.size(0) == 0:
+            # no positive ROIs in this batch: register zero-valued losses so
+            # the returned keys stay consistent across iterations
+            loss_mask = mask_pred.sum()
+            loss['loss_mask_vis'] = loss_mask
+            loss['loss_mask_full'] = loss_mask
+        else:
+            if self.class_agnostic:
+                loss_mask = self.loss_mask(mask_pred, mask_targets,
+                                           torch.zeros_like(labels))
+            else:
+                # channel 0 predicts the visible mask and channel 1 the full
+                # (amodal) mask; the targets are interleaved accordingly
+                loss_mask_vis = self.loss_mask(mask_pred[:, 0:1],
+                                               mask_targets[0::2], labels)
+                loss_mask_full = self.loss_mask(mask_pred[:, 1:2],
+                                                mask_targets[1::2], labels)
+ loss['loss_mask_vis'] = loss_mask_vis
+ loss['loss_mask_full'] = loss_mask_full
+ return loss
+
+ def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
+ ori_shape, scale_factor, rescale):
+ """Get segmentation masks from mask_pred and bboxes.
+ Args:
+ mask_pred (Tensor or ndarray): shape (n, #class, h, w).
+ For single-scale testing, mask_pred is the direct output of
+ model, whose type is Tensor, while for multi-scale testing,
+ it will be converted to numpy array outside of this method.
+ det_bboxes (Tensor): shape (n, 4/5)
+ det_labels (Tensor): shape (n, )
+ rcnn_test_cfg (dict): rcnn testing config
+ ori_shape (Tuple): original image height and width, shape (2,)
+ scale_factor(float | Tensor): If ``rescale is True``, box
+ coordinates are divided by this scale factor to fit
+ ``ori_shape``.
+ rescale (bool): If True, the resulting masks will be rescaled to
+ ``ori_shape``.
+ Returns:
+ list[list]: encoded masks. The c-th item in the outer list
+ corresponds to the c-th class. Given the c-th outer list, the
+ i-th item in that inner list is the mask for the i-th box with
+ class label c.
+ Example:
+ >>> import mmcv
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> self = FCNMaskHead(num_classes=C, num_convs=0)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> # Each input is associated with some bounding box
+ >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
+ >>> det_labels = torch.randint(0, C, size=(N,))
+ >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
+ >>> ori_shape = (H * 4, W * 4)
+ >>> scale_factor = torch.FloatTensor((1, 1))
+ >>> rescale = False
+ >>> # Encoded masks are a list for each category.
+ >>> encoded_masks = self.get_seg_masks(
+ >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
+ >>> scale_factor, rescale
+ >>> )
+ >>> assert len(encoded_masks) == C
+ >>> assert sum(list(map(len, encoded_masks))) == N
+ """
+ if isinstance(mask_pred, torch.Tensor):
+ mask_pred = mask_pred.sigmoid()
+ else:
+ mask_pred = det_bboxes.new_tensor(mask_pred)
+
+ device = mask_pred.device
+ cls_segms = [[] for _ in range(self.num_classes)
+ ] # BG is not included in num_classes
+ bboxes = det_bboxes[:, :4]
+ labels = det_labels
+
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ else:
+ if isinstance(scale_factor, float):
+ img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
+ img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
+ else:
+ w_scale, h_scale = scale_factor[0], scale_factor[1]
+ img_h = np.round(ori_shape[0] * h_scale.item()).astype(
+ np.int32)
+ img_w = np.round(ori_shape[1] * w_scale.item()).astype(
+ np.int32)
+ scale_factor = 1.0
+
+ if not isinstance(scale_factor, (float, torch.Tensor)):
+ scale_factor = bboxes.new_tensor(scale_factor)
+ bboxes = bboxes / scale_factor
+
+ if torch.onnx.is_in_onnx_export():
+ # TODO: Remove after F.grid_sample is supported.
+ from torchvision.models.detection.roi_heads \
+ import paste_masks_in_image
+ masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
+ thr = rcnn_test_cfg.get('mask_thr_binary', 0)
+ if thr > 0:
+ masks = masks >= thr
+ return masks
+
+ N = len(mask_pred)
+        # The actual implementation splits the input into chunks
+        # and pastes them chunk by chunk.
+ if device.type == 'cpu':
+ # CPU is most efficient when they are pasted one by one with
+ # skip_empty=True, so that it performs minimal number of
+ # operations.
+ num_chunks = N
+ else:
+ # GPU benefits from parallelism for larger chunks,
+            # but may run into memory issues
+ num_chunks = int(
+ np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
+ assert (num_chunks <=
+ N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
+ chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
+
+ threshold = rcnn_test_cfg.mask_thr_binary
+ im_mask = torch.zeros(
+ N,
+ img_h,
+ img_w,
+ device=device,
+ dtype=torch.bool if threshold >= 0 else torch.uint8)
+
+ if not self.class_agnostic:
+ mask_pred = mask_pred[range(N), labels][:, None]
+
+ for inds in chunks:
+ masks_chunk, spatial_inds = _do_paste_mask(
+ mask_pred[inds],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+
+ if threshold >= 0:
+ masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
+ else:
+ # for visualization and debugging
+ masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
+
+ im_mask[(inds, ) + spatial_inds] = masks_chunk
+
+ for i in range(N):
+ cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
+ return cls_segms
+
+ def get_seg_masks1(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
+ ori_shape, scale_factor, rescale):
+ """Get segmentation masks from mask_pred and bboxes.
+
+ Args:
+ mask_pred (Tensor or ndarray): shape (n, #class, h, w).
+ For single-scale testing, mask_pred is the direct output of
+ model, whose type is Tensor, while for multi-scale testing,
+ it will be converted to numpy array outside of this method.
+ det_bboxes (Tensor): shape (n, 4/5)
+ det_labels (Tensor): shape (n, )
+ rcnn_test_cfg (dict): rcnn testing config
+ ori_shape (Tuple): original image height and width, shape (2,)
+ scale_factor(float | Tensor): If ``rescale is True``, box
+ coordinates are divided by this scale factor to fit
+ ``ori_shape``.
+ rescale (bool): If True, the resulting masks will be rescaled to
+ ``ori_shape``.
+
+ Returns:
+ list[list]: encoded masks. The c-th item in the outer list
+ corresponds to the c-th class. Given the c-th outer list, the
+ i-th item in that inner list is the mask for the i-th box with
+ class label c.
+
+ Example:
+ >>> import mmcv
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> self = FCNMaskHead(num_classes=C, num_convs=0)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> # Each input is associated with some bounding box
+ >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
+ >>> det_labels = torch.randint(0, C, size=(N,))
+ >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
+ >>> ori_shape = (H * 4, W * 4)
+ >>> scale_factor = torch.FloatTensor((1, 1))
+ >>> rescale = False
+ >>> # Encoded masks are a list for each category.
+ >>> encoded_masks = self.get_seg_masks(
+ >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
+ >>> scale_factor, rescale
+ >>> )
+ >>> assert len(encoded_masks) == C
+ >>> assert sum(list(map(len, encoded_masks))) == N
+ """
+ if isinstance(mask_pred, torch.Tensor):
+ mask_pred = mask_pred.sigmoid()
+ else:
+ mask_pred = det_bboxes.new_tensor(mask_pred)
+
+ device = mask_pred.device
+ cls_segms = [[] for _ in range(self.num_classes)
+ ] # BG is not included in num_classes
+ bboxes = det_bboxes[:, :4]
+ labels = det_labels
+        # duplicate the entries so that both mask channels (visible and
+        # occluded) can be pasted with the same box below
+        labels = torch.cat((labels, torch.tensor([1])))
+        bboxes = torch.cat((bboxes, bboxes))
+
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ else:
+ if isinstance(scale_factor, float):
+ img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
+ img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
+ else:
+ w_scale, h_scale = scale_factor[0], scale_factor[1]
+ img_h = np.round(ori_shape[0] * h_scale.item()).astype(
+ np.int32)
+ img_w = np.round(ori_shape[1] * w_scale.item()).astype(
+ np.int32)
+ scale_factor = 1.0
+
+ if not isinstance(scale_factor, (float, torch.Tensor)):
+ scale_factor = bboxes.new_tensor(scale_factor)
+ bboxes = bboxes / scale_factor
+
+ if torch.onnx.is_in_onnx_export():
+ # TODO: Remove after F.grid_sample is supported.
+ from torchvision.models.detection.roi_heads \
+ import paste_masks_in_image
+ masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
+ thr = rcnn_test_cfg.get('mask_thr_binary', 0)
+ if thr > 0:
+ masks = masks >= thr
+ return masks
+
+ N = len(mask_pred)
+        # The actual implementation splits the input into chunks
+        # and pastes them chunk by chunk.
+ if device.type == 'cpu':
+ # CPU is most efficient when they are pasted one by one with
+ # skip_empty=True, so that it performs minimal number of
+ # operations.
+ num_chunks = N
+ else:
+ # GPU benefits from parallelism for larger chunks,
+            # but may run into memory issues
+ num_chunks = int(
+ np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
+ assert (num_chunks <=
+ N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
+ chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
+
+ threshold = rcnn_test_cfg.mask_thr_binary
+ im_mask = torch.zeros(
+ N,
+ img_h,
+ img_w,
+ device=device,
+ dtype=torch.bool if threshold >= 0 else torch.uint8)
+
+ if not self.class_agnostic:
+ mask_pred = mask_pred[range(N), labels][:, None]
+
+        for inds in chunks:
+ masks_chunk, spatial_inds = _do_paste_mask(
+ mask_pred[0:1],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+ masks_chunk_occ, spatial_inds_occ = _do_paste_mask(
+ mask_pred[1:2],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+
+
+ if threshold >= 0:
+ masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
+ masks_chunk_occ = (masks_chunk_occ >= threshold).to(dtype=torch.bool)
+ else:
+ # for visualization and debugging
+ masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
+
+ im_mask[([0], ) + spatial_inds] = masks_chunk
+ im_mask[([1], ) + spatial_inds] = masks_chunk_occ
+
+
+ for i in range(N):
+ cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
+ return cls_segms
+
+
+def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
+ """Paste instance masks according to boxes.
+
+ This implementation is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ Args:
+ masks (Tensor): N, 1, H, W
+ boxes (Tensor): N, 4
+ img_h (int): Height of the image to be pasted.
+ img_w (int): Width of the image to be pasted.
+        skip_empty (bool): Only paste masks within the region that
+            tightly bounds all boxes, and return the results for this
+            region only. An important optimization for CPU.
+
+ Returns:
+ tuple: (Tensor, tuple). The first item is mask tensor, the second one
+ is the slice object.
+ If skip_empty == False, the whole image will be pasted. It will
+ return a mask of shape (N, img_h, img_w) and an empty tuple.
+ If skip_empty == True, only area around the mask will be pasted.
+ A mask of shape (N, h', w') and its start and end coordinates
+ in the original image will be returned.
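+
+    Example:
+        >>> # Illustrative shape sketch (skip_empty=False pastes onto the
+        >>> # full image and returns an empty tuple of spatial indices)
+        >>> import torch
+        >>> masks = torch.rand(2, 1, 28, 28)
+        >>> boxes = torch.tensor([[10., 20., 60., 50.], [5., 5., 40., 30.]])
+        >>> pasted, spatial_inds = _do_paste_mask(
+        >>>     masks, boxes, img_h=100, img_w=120, skip_empty=False)
+        >>> assert pasted.shape == (2, 100, 120) and spatial_inds == ()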
+ """
+ # On GPU, paste all masks together (up to chunk size)
+ # by using the entire image to sample the masks
+ # Compared to pasting them one by one,
+ # this has more operations but is faster on COCO-scale dataset.
+ device = masks.device
+ if skip_empty:
+ x0_int, y0_int = torch.clamp(
+ boxes.min(dim=0).values.floor()[:2] - 1,
+ min=0).to(dtype=torch.int32)
+ x1_int = torch.clamp(
+ boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
+ y1_int = torch.clamp(
+ boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
+ else:
+ x0_int, y0_int = 0, 0
+ x1_int, y1_int = img_w, img_h
+ x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
+
+ N = masks.shape[0]
+
+ img_y = torch.arange(
+ y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
+ img_x = torch.arange(
+ x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
+ img_y = (img_y - y0) / (y1 - y0) * 2 - 1
+ img_x = (img_x - x0) / (x1 - x0) * 2 - 1
+ # img_x, img_y have shapes (N, w), (N, h)
+ if torch.isinf(img_x).any():
+ inds = torch.where(torch.isinf(img_x))
+ img_x[inds] = 0
+ if torch.isinf(img_y).any():
+ inds = torch.where(torch.isinf(img_y))
+ img_y[inds] = 0
+
+ gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
+ gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+ grid = torch.stack([gx, gy], dim=3)
+
+ if torch.onnx.is_in_onnx_export():
+ raise RuntimeError(
+ 'Exporting F.grid_sample from Pytorch to ONNX is not supported.')
+ img_masks = F.grid_sample(
+ masks.to(dtype=torch.float32), grid, align_corners=False)
+
+ if skip_empty:
+ return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
+ else:
+ return img_masks[:, 0], ()
diff --git a/mmdet/models/roi_heads/mask_heads/fcn_occmask_head.py b/mmdet/models/roi_heads/mask_heads/fcn_occmask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..17953ed183cc5f1cd55af7d3196fe6ffa4aa06db
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/fcn_occmask_head.py
@@ -0,0 +1,570 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import Conv2d, ConvModule, build_upsample_layer
+from mmcv.ops.carafe import CARAFEPack
+from mmcv.runner import auto_fp16, force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.core import mask_target
+from mmdet.models.builder import HEADS, build_loss
+
+BYTES_PER_FLOAT = 4
+# TODO: This memory limit may be too much or too little. It would be better to
+# determine it based on available resources.
+GPU_MEM_LIMIT = 1024**3 # 1 GB memory limit
+
+
+@HEADS.register_module()
+class FCNOccMaskHead(nn.Module):
+
+ def __init__(self,
+ num_convs=4,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_kernel_size=3,
+ conv_out_channels=256,
+ num_classes=80,
+ class_agnostic=False,
+ upsample_cfg=dict(type='deconv', scale_factor=2),
+ conv_cfg=None,
+ norm_cfg=None,
+ loss_mask=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
+ super(FCNOccMaskHead, self).__init__()
+ self.upsample_cfg = upsample_cfg.copy()
+ if self.upsample_cfg['type'] not in [
+ None, 'deconv', 'nearest', 'bilinear', 'carafe'
+ ]:
+ raise ValueError(
+ f'Invalid upsample method {self.upsample_cfg["type"]}, '
+ 'accepted methods are "deconv", "nearest", "bilinear", '
+ '"carafe"')
+ self.num_convs = num_convs
+ # WARN: roi_feat_size is reserved and not used
+ self.roi_feat_size = _pair(roi_feat_size)
+ self.in_channels = in_channels
+ self.conv_kernel_size = conv_kernel_size
+ self.conv_out_channels = conv_out_channels
+ self.upsample_method = self.upsample_cfg.get('type')
+ self.scale_factor = self.upsample_cfg.pop('scale_factor', None)
+ self.num_classes = num_classes
+ self.class_agnostic = class_agnostic
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.fp16_enabled = False
+ self.loss_mask = build_loss(loss_mask)
+
+ self.convs = nn.ModuleList()
+ for i in range(self.num_convs):
+            if i == 0:
+                # the first conv consumes the concatenation of the RoI
+                # features with the occluder-branch features (see forward)
+                in_channels_change = in_channels * 2
+            else:
+                in_channels_change = in_channels
+
+ in_channels = (
+ self.in_channels if i == 0 else self.conv_out_channels)
+ padding = (self.conv_kernel_size - 1) // 2
+ self.convs.append(
+ ConvModule(
+ in_channels_change,
+ self.conv_out_channels,
+ self.conv_kernel_size,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+
+ self.convs_occluder = nn.ModuleList()
+ for i in range(self.num_convs):
+ in_channels = (
+ self.in_channels if i == 0 else self.conv_out_channels)
+ padding = (self.conv_kernel_size - 1) // 2
+ self.convs_occluder.append(
+ ConvModule(
+ in_channels,
+ self.conv_out_channels,
+ self.conv_kernel_size,
+ padding=padding,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg))
+
+ upsample_in_channels = (
+ self.conv_out_channels if self.num_convs > 0 else in_channels)
+ upsample_cfg_ = self.upsample_cfg.copy()
+ if self.upsample_method is None:
+ self.upsample = None
+ elif self.upsample_method == 'deconv':
+ upsample_cfg_.update(
+ in_channels=upsample_in_channels,
+ out_channels=self.conv_out_channels,
+ kernel_size=self.scale_factor,
+ stride=self.scale_factor)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+ elif self.upsample_method == 'carafe':
+ upsample_cfg_.update(
+ channels=upsample_in_channels, scale_factor=self.scale_factor)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+ else:
+ # suppress warnings
+ align_corners = (None
+ if self.upsample_method == 'nearest' else False)
+ upsample_cfg_.update(
+ scale_factor=self.scale_factor,
+ mode=self.upsample_method,
+ align_corners=align_corners)
+ self.upsample = build_upsample_layer(upsample_cfg_)
+
+ out_channels = 1 if self.class_agnostic else self.num_classes
+ logits_in_channel = (
+ self.conv_out_channels
+ if self.upsample_method == 'deconv' else upsample_in_channels)
+ self.conv_logits = Conv2d(logits_in_channel, out_channels, 1)
+ self.conv_logits_occluder = Conv2d(logits_in_channel, out_channels, 1)
+ self.relu = nn.ReLU(inplace=True)
+ self.debug_imgs = None
+
+ def init_weights(self):
+ for m in [self.upsample, self.conv_logits]:
+ if m is None:
+ continue
+ elif isinstance(m, CARAFEPack):
+ m.init_weights()
+ else:
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ nn.init.constant_(m.bias, 0)
+
+ @auto_fp16()
+ def forward(self, x):
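+        # Two-branch head: `convs_occluder` encodes the occluder from a copy
+        # of the RoI features, its output is concatenated with the original
+        # features before `convs`, and each branch gets its own mask logits.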
+ y = x.clone()
+ for conv in self.convs_occluder:
+ y = conv(y)
+ x = torch.cat((x, y), 1)
+ for conv in self.convs:
+ x = conv(x)
+ if self.upsample is not None:
+ x = self.upsample(x)
+ if self.upsample_method == 'deconv':
+ x = self.relu(x)
+ if self.upsample is not None:
+ y = self.upsample(y)
+ if self.upsample_method == 'deconv':
+ y = self.relu(y)
+ mask_pred = self.conv_logits(x)
+ mask_occluder_pred = self.conv_logits_occluder(y)
+ return mask_pred, mask_occluder_pred
+
+ def get_targets(self, sampling_results, gt_masks, rcnn_train_cfg):
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ pos_assigned_gt_inds = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+ mask_targets = mask_target(pos_proposals, pos_assigned_gt_inds,
+ gt_masks, rcnn_train_cfg)
+ return mask_targets
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, mask_targets, labels):
+ """
+ Example:
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> # There are lots of variations depending on the configuration
+ >>> self = FCNMaskHead(num_classes=C, num_convs=1)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> sf = self.scale_factor
+ >>> labels = torch.randint(0, C, size=(N,))
+ >>> # With the default properties the mask targets should indicate
+ >>> # a (potentially soft) single-class label
+ >>> mask_targets = torch.rand(N, H * sf, W * sf)
+ >>> loss = self.loss(mask_pred, mask_targets, labels)
+ >>> print('loss = {!r}'.format(loss))
+ """
+ mask_full_pred, mask_occ_pred = mask_pred
+ loss = dict()
+        if mask_full_pred.size(0) == 0:
+            # no positive ROIs: register a zero-valued loss to keep the key
+            loss['loss_mask_vis'] = mask_full_pred.sum()
+        else:
+            if self.class_agnostic:
+                loss_mask_vis = self.loss_mask(mask_full_pred, mask_targets,
+                                               torch.zeros_like(labels))
+            else:
+                # the object branch is trained on the even-indexed
+                # (visible-mask) targets
+                loss_mask_vis = self.loss_mask(mask_full_pred[:, 0:1],
+                                               mask_targets[0::2], labels)
+            loss['loss_mask_vis'] = loss_mask_vis
+
+        if mask_occ_pred.size(0) == 0:
+            loss['loss_mask_occ'] = mask_occ_pred.sum()
+        else:
+            if self.class_agnostic:
+                loss_mask_occ = self.loss_mask(mask_occ_pred, mask_targets,
+                                               torch.zeros_like(labels))
+            else:
+                # the occluder branch is trained on the odd-indexed targets
+                loss_mask_occ = self.loss_mask(mask_occ_pred[:, 0:1],
+                                               mask_targets[1::2], labels)
+            loss['loss_mask_occ'] = loss_mask_occ
+ return loss
+
+ def get_seg_masks(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
+ ori_shape, scale_factor, rescale):
+ """Get segmentation masks from mask_pred and bboxes.
+ Args:
+ mask_pred (Tensor or ndarray): shape (n, #class, h, w).
+ For single-scale testing, mask_pred is the direct output of
+ model, whose type is Tensor, while for multi-scale testing,
+ it will be converted to numpy array outside of this method.
+ det_bboxes (Tensor): shape (n, 4/5)
+ det_labels (Tensor): shape (n, )
+ rcnn_test_cfg (dict): rcnn testing config
+ ori_shape (Tuple): original image height and width, shape (2,)
+ scale_factor(float | Tensor): If ``rescale is True``, box
+ coordinates are divided by this scale factor to fit
+ ``ori_shape``.
+ rescale (bool): If True, the resulting masks will be rescaled to
+ ``ori_shape``.
+ Returns:
+ list[list]: encoded masks. The c-th item in the outer list
+ corresponds to the c-th class. Given the c-th outer list, the
+ i-th item in that inner list is the mask for the i-th box with
+ class label c.
+ Example:
+ >>> import mmcv
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> self = FCNMaskHead(num_classes=C, num_convs=0)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> # Each input is associated with some bounding box
+ >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
+ >>> det_labels = torch.randint(0, C, size=(N,))
+ >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
+ >>> ori_shape = (H * 4, W * 4)
+ >>> scale_factor = torch.FloatTensor((1, 1))
+ >>> rescale = False
+ >>> # Encoded masks are a list for each category.
+ >>> encoded_masks = self.get_seg_masks(
+ >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
+ >>> scale_factor, rescale
+ >>> )
+ >>> assert len(encoded_masks) == C
+ >>> assert sum(list(map(len, encoded_masks))) == N
+ """
+ if isinstance(mask_pred, torch.Tensor):
+ mask_pred = mask_pred.sigmoid()
+ else:
+ mask_pred = det_bboxes.new_tensor(mask_pred)
+
+ device = mask_pred.device
+ cls_segms = [[] for _ in range(self.num_classes)
+ ] # BG is not included in num_classes
+ bboxes = det_bboxes[:, :4]
+ labels = det_labels
+
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ else:
+ if isinstance(scale_factor, float):
+ img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
+ img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
+ else:
+ w_scale, h_scale = scale_factor[0], scale_factor[1]
+ img_h = np.round(ori_shape[0] * h_scale.item()).astype(
+ np.int32)
+ img_w = np.round(ori_shape[1] * w_scale.item()).astype(
+ np.int32)
+ scale_factor = 1.0
+
+ if not isinstance(scale_factor, (float, torch.Tensor)):
+ scale_factor = bboxes.new_tensor(scale_factor)
+ bboxes = bboxes / scale_factor
+
+ if torch.onnx.is_in_onnx_export():
+ # TODO: Remove after F.grid_sample is supported.
+ from torchvision.models.detection.roi_heads \
+ import paste_masks_in_image
+ masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
+ thr = rcnn_test_cfg.get('mask_thr_binary', 0)
+ if thr > 0:
+ masks = masks >= thr
+ return masks
+
+ N = len(mask_pred)
+        # The actual implementation splits the input into chunks
+        # and pastes them chunk by chunk.
+ if device.type == 'cpu':
+ # CPU is most efficient when they are pasted one by one with
+ # skip_empty=True, so that it performs minimal number of
+ # operations.
+ num_chunks = N
+ else:
+ # GPU benefits from parallelism for larger chunks,
+            # but may run into memory issues
+ num_chunks = int(
+ np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
+ assert (num_chunks <=
+ N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
+ chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
+
+ threshold = rcnn_test_cfg.mask_thr_binary
+ im_mask = torch.zeros(
+ N,
+ img_h,
+ img_w,
+ device=device,
+ dtype=torch.bool if threshold >= 0 else torch.uint8)
+
+ if not self.class_agnostic:
+ mask_pred = mask_pred[range(N), labels][:, None]
+
+ for inds in chunks:
+ masks_chunk, spatial_inds = _do_paste_mask(
+ mask_pred[inds],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+
+ if threshold >= 0:
+ masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
+ else:
+ # for visualization and debugging
+ masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
+
+ im_mask[(inds, ) + spatial_inds] = masks_chunk
+
+ for i in range(N):
+ cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
+ return cls_segms
+
+ def get_seg_masks1(self, mask_pred, det_bboxes, det_labels, rcnn_test_cfg,
+ ori_shape, scale_factor, rescale):
+ """Get segmentation masks from mask_pred and bboxes.
+
+ Args:
+ mask_pred (Tensor or ndarray): shape (n, #class, h, w).
+ For single-scale testing, mask_pred is the direct output of
+ model, whose type is Tensor, while for multi-scale testing,
+ it will be converted to numpy array outside of this method.
+ det_bboxes (Tensor): shape (n, 4/5)
+ det_labels (Tensor): shape (n, )
+ rcnn_test_cfg (dict): rcnn testing config
+ ori_shape (Tuple): original image height and width, shape (2,)
+ scale_factor(float | Tensor): If ``rescale is True``, box
+ coordinates are divided by this scale factor to fit
+ ``ori_shape``.
+ rescale (bool): If True, the resulting masks will be rescaled to
+ ``ori_shape``.
+
+ Returns:
+ list[list]: encoded masks. The c-th item in the outer list
+ corresponds to the c-th class. Given the c-th outer list, the
+ i-th item in that inner list is the mask for the i-th box with
+ class label c.
+
+ Example:
+ >>> import mmcv
+ >>> from mmdet.models.roi_heads.mask_heads.fcn_mask_head import * # NOQA
+ >>> N = 7 # N = number of extracted ROIs
+ >>> C, H, W = 11, 32, 32
+ >>> # Create example instance of FCN Mask Head.
+ >>> self = FCNMaskHead(num_classes=C, num_convs=0)
+ >>> inputs = torch.rand(N, self.in_channels, H, W)
+ >>> mask_pred = self.forward(inputs)
+ >>> # Each input is associated with some bounding box
+ >>> det_bboxes = torch.Tensor([[1, 1, 42, 42 ]] * N)
+ >>> det_labels = torch.randint(0, C, size=(N,))
+ >>> rcnn_test_cfg = mmcv.Config({'mask_thr_binary': 0, })
+ >>> ori_shape = (H * 4, W * 4)
+ >>> scale_factor = torch.FloatTensor((1, 1))
+ >>> rescale = False
+ >>> # Encoded masks are a list for each category.
+ >>> encoded_masks = self.get_seg_masks(
+ >>> mask_pred, det_bboxes, det_labels, rcnn_test_cfg, ori_shape,
+ >>> scale_factor, rescale
+ >>> )
+ >>> assert len(encoded_masks) == C
+ >>> assert sum(list(map(len, encoded_masks))) == N
+ """
+ if isinstance(mask_pred, torch.Tensor):
+ mask_pred = mask_pred.sigmoid()
+ else:
+ mask_pred = det_bboxes.new_tensor(mask_pred)
+
+ device = mask_pred.device
+ cls_segms = [[] for _ in range(self.num_classes)
+ ] # BG is not included in num_classes
+ bboxes = det_bboxes[:, :4]
+ labels = det_labels
+        # duplicate the entries so that both mask channels (visible and
+        # occluder) can be pasted with the same box below
+        labels = torch.cat((labels, torch.tensor([1])))
+        bboxes = torch.cat((bboxes, bboxes))
+
+ if rescale:
+ img_h, img_w = ori_shape[:2]
+ else:
+ if isinstance(scale_factor, float):
+ img_h = np.round(ori_shape[0] * scale_factor).astype(np.int32)
+ img_w = np.round(ori_shape[1] * scale_factor).astype(np.int32)
+ else:
+ w_scale, h_scale = scale_factor[0], scale_factor[1]
+ img_h = np.round(ori_shape[0] * h_scale.item()).astype(
+ np.int32)
+ img_w = np.round(ori_shape[1] * w_scale.item()).astype(
+ np.int32)
+ scale_factor = 1.0
+
+ if not isinstance(scale_factor, (float, torch.Tensor)):
+ scale_factor = bboxes.new_tensor(scale_factor)
+ bboxes = bboxes / scale_factor
+
+ if torch.onnx.is_in_onnx_export():
+ # TODO: Remove after F.grid_sample is supported.
+ from torchvision.models.detection.roi_heads \
+ import paste_masks_in_image
+ masks = paste_masks_in_image(mask_pred, bboxes, ori_shape[:2])
+ thr = rcnn_test_cfg.get('mask_thr_binary', 0)
+ if thr > 0:
+ masks = masks >= thr
+ return masks
+
+ N = len(mask_pred)
+        # The actual implementation splits the input into chunks
+        # and pastes them chunk by chunk.
+ if device.type == 'cpu':
+ # CPU is most efficient when they are pasted one by one with
+ # skip_empty=True, so that it performs minimal number of
+ # operations.
+ num_chunks = N
+ else:
+ # GPU benefits from parallelism for larger chunks,
+            # but may run into memory issues
+ num_chunks = int(
+ np.ceil(N * img_h * img_w * BYTES_PER_FLOAT / GPU_MEM_LIMIT))
+ assert (num_chunks <=
+ N), 'Default GPU_MEM_LIMIT is too small; try increasing it'
+ chunks = torch.chunk(torch.arange(N, device=device), num_chunks)
+
+ threshold = rcnn_test_cfg.mask_thr_binary
+ im_mask = torch.zeros(
+ N,
+ img_h,
+ img_w,
+ device=device,
+ dtype=torch.bool if threshold >= 0 else torch.uint8)
+
+ if not self.class_agnostic:
+ mask_pred = mask_pred[range(N), labels][:, None]
+
+        for inds in chunks:
+ masks_chunk, spatial_inds = _do_paste_mask(
+ mask_pred[0:1],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+ masks_chunk_occ, spatial_inds_occ = _do_paste_mask(
+ mask_pred[1:2],
+ bboxes[inds],
+ img_h,
+ img_w,
+ skip_empty=device.type == 'cpu')
+
+
+ if threshold >= 0:
+ masks_chunk = (masks_chunk >= threshold).to(dtype=torch.bool)
+ masks_chunk_occ = (masks_chunk_occ >= threshold).to(dtype=torch.bool)
+ else:
+ # for visualization and debugging
+ masks_chunk = (masks_chunk * 255).to(dtype=torch.uint8)
+
+ im_mask[([0], ) + spatial_inds] = masks_chunk
+ im_mask[([1], ) + spatial_inds] = masks_chunk_occ
+
+
+ for i in range(N):
+ cls_segms[labels[i]].append(im_mask[i].detach().cpu().numpy())
+ return cls_segms
+
+
+def _do_paste_mask(masks, boxes, img_h, img_w, skip_empty=True):
+ """Paste instance masks according to boxes.
+
+ This implementation is modified from
+ https://github.com/facebookresearch/detectron2/
+
+ Args:
+ masks (Tensor): N, 1, H, W
+ boxes (Tensor): N, 4
+ img_h (int): Height of the image to be pasted.
+ img_w (int): Width of the image to be pasted.
+        skip_empty (bool): Only paste masks within the region that
+            tightly bounds all boxes, and return the results for this
+            region only. An important optimization for CPU.
+
+ Returns:
+ tuple: (Tensor, tuple). The first item is mask tensor, the second one
+ is the slice object.
+ If skip_empty == False, the whole image will be pasted. It will
+ return a mask of shape (N, img_h, img_w) and an empty tuple.
+ If skip_empty == True, only area around the mask will be pasted.
+ A mask of shape (N, h', w') and its start and end coordinates
+ in the original image will be returned.
+ """
+ # On GPU, paste all masks together (up to chunk size)
+ # by using the entire image to sample the masks
+ # Compared to pasting them one by one,
+ # this has more operations but is faster on COCO-scale dataset.
+ device = masks.device
+ if skip_empty:
+ x0_int, y0_int = torch.clamp(
+ boxes.min(dim=0).values.floor()[:2] - 1,
+ min=0).to(dtype=torch.int32)
+ x1_int = torch.clamp(
+ boxes[:, 2].max().ceil() + 1, max=img_w).to(dtype=torch.int32)
+ y1_int = torch.clamp(
+ boxes[:, 3].max().ceil() + 1, max=img_h).to(dtype=torch.int32)
+ else:
+ x0_int, y0_int = 0, 0
+ x1_int, y1_int = img_w, img_h
+ x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
+
+ N = masks.shape[0]
+
+ img_y = torch.arange(
+ y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
+ img_x = torch.arange(
+ x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
+ img_y = (img_y - y0) / (y1 - y0) * 2 - 1
+ img_x = (img_x - x0) / (x1 - x0) * 2 - 1
+ # img_x, img_y have shapes (N, w), (N, h)
+ if torch.isinf(img_x).any():
+ inds = torch.where(torch.isinf(img_x))
+ img_x[inds] = 0
+ if torch.isinf(img_y).any():
+ inds = torch.where(torch.isinf(img_y))
+ img_y[inds] = 0
+
+ gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
+ gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
+ grid = torch.stack([gx, gy], dim=3)
+
+ if torch.onnx.is_in_onnx_export():
+ raise RuntimeError(
+ 'Exporting F.grid_sample from Pytorch to ONNX is not supported.')
+ img_masks = F.grid_sample(
+ masks.to(dtype=torch.float32), grid, align_corners=False)
+
+ if skip_empty:
+ return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
+ else:
+ return img_masks[:, 0], ()
diff --git a/mmdet/models/roi_heads/mask_heads/feature_relay_head.py b/mmdet/models/roi_heads/mask_heads/feature_relay_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1cfb2ce8631d51e5c465f9bbc4164a37acc4782
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/feature_relay_head.py
@@ -0,0 +1,55 @@
+import torch.nn as nn
+from mmcv.cnn import kaiming_init
+from mmcv.runner import auto_fp16
+
+from mmdet.models.builder import HEADS
+
+
+@HEADS.register_module()
+class FeatureRelayHead(nn.Module):
+    """Feature Relay Head used in `SCNet <https://arxiv.org/abs/2012.10150>`_.
+
+ Args:
+        in_channels (int, optional): number of input channels. Default: 1024.
+        out_conv_channels (int, optional): number of output channels of the
+            relayed feature map. Default: 256.
+ roi_feat_size (int, optional): roi feat size at box head. Default: 7.
+ scale_factor (int, optional): scale factor to match roi feat size
+ at mask head. Default: 2.
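+
+    Example:
+        >>> # Illustrative shape sketch with the default settings
+        >>> import torch
+        >>> self = FeatureRelayHead()
+        >>> x = torch.rand(4, 1024)  # flattened box-head features
+        >>> out = self.forward(x)
+        >>> assert out.shape == (4, 256, 14, 14)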
+ """
+
+ def __init__(self,
+ in_channels=1024,
+ out_conv_channels=256,
+ roi_feat_size=7,
+ scale_factor=2):
+ super(FeatureRelayHead, self).__init__()
+ assert isinstance(roi_feat_size, int)
+
+ self.in_channels = in_channels
+ self.out_conv_channels = out_conv_channels
+ self.roi_feat_size = roi_feat_size
+ self.out_channels = (roi_feat_size**2) * out_conv_channels
+ self.scale_factor = scale_factor
+ self.fp16_enabled = False
+
+ self.fc = nn.Linear(self.in_channels, self.out_channels)
+ self.upsample = nn.Upsample(
+ scale_factor=scale_factor, mode='bilinear', align_corners=True)
+
+ def init_weights(self):
+ """Init weights for the head."""
+ kaiming_init(self.fc)
+
+ @auto_fp16()
+ def forward(self, x):
+ """Forward function."""
+ N, in_C = x.shape
+ if N > 0:
+ out_C = self.out_conv_channels
+ out_HW = self.roi_feat_size
+ x = self.fc(x)
+ x = x.reshape(N, out_C, out_HW, out_HW)
+ x = self.upsample(x)
+ return x
+ return None
diff --git a/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py b/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..2aa6033eec17a30aeb68c0fdd218d8f0d41157e8
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/fused_semantic_head.py
@@ -0,0 +1,107 @@
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, kaiming_init
+from mmcv.runner import auto_fp16, force_fp32
+
+from mmdet.models.builder import HEADS
+
+
+@HEADS.register_module()
+class FusedSemanticHead(nn.Module):
+ r"""Multi-level fused semantic segmentation head.
+
+ .. code-block:: none
+
+        in_1 -> 1x1 conv ---
+                            |
+        in_2 -> 1x1 conv -- |
+                           ||
+        in_3 -> 1x1 conv - ||
+                          |||                  /-> 1x1 conv (mask prediction)
+        in_4 -> 1x1 conv -----> 3x3 convs (*4)
+                            |                  \-> 1x1 conv (feature)
+        in_5 -> 1x1 conv ---
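+
+    Example:
+        >>> # Illustrative shape sketch; both outputs live at the resolution
+        >>> # of the fusion level
+        >>> import torch
+        >>> self = FusedSemanticHead(num_ins=5, fusion_level=1)
+        >>> feats = [torch.rand(1, 256, s, s) for s in (64, 32, 16, 8, 4)]
+        >>> mask_pred, feat = self.forward(feats)
+        >>> assert mask_pred.shape == (1, 183, 32, 32)
+        >>> assert feat.shape == (1, 256, 32, 32)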
+ """ # noqa: W605
+
+ def __init__(self,
+ num_ins,
+ fusion_level,
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=183,
+ ignore_label=255,
+ loss_weight=0.2,
+ conv_cfg=None,
+ norm_cfg=None):
+ super(FusedSemanticHead, self).__init__()
+ self.num_ins = num_ins
+ self.fusion_level = fusion_level
+ self.num_convs = num_convs
+ self.in_channels = in_channels
+ self.conv_out_channels = conv_out_channels
+ self.num_classes = num_classes
+ self.ignore_label = ignore_label
+ self.loss_weight = loss_weight
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.fp16_enabled = False
+
+ self.lateral_convs = nn.ModuleList()
+ for i in range(self.num_ins):
+ self.lateral_convs.append(
+ ConvModule(
+ self.in_channels,
+ self.in_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ inplace=False))
+
+ self.convs = nn.ModuleList()
+ for i in range(self.num_convs):
+ in_channels = self.in_channels if i == 0 else conv_out_channels
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ conv_out_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+ self.conv_embedding = ConvModule(
+ conv_out_channels,
+ conv_out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ self.conv_logits = nn.Conv2d(conv_out_channels, self.num_classes, 1)
+
+ self.criterion = nn.CrossEntropyLoss(ignore_index=ignore_label)
+
+ def init_weights(self):
+ kaiming_init(self.conv_logits)
+
+ @auto_fp16()
+ def forward(self, feats):
+ x = self.lateral_convs[self.fusion_level](feats[self.fusion_level])
+ fused_size = tuple(x.shape[-2:])
+ for i, feat in enumerate(feats):
+ if i != self.fusion_level:
+ feat = F.interpolate(
+ feat, size=fused_size, mode='bilinear', align_corners=True)
+ x += self.lateral_convs[i](feat)
+
+ for i in range(self.num_convs):
+ x = self.convs[i](x)
+
+ mask_pred = self.conv_logits(x)
+ x = self.conv_embedding(x)
+ return mask_pred, x
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def loss(self, mask_pred, labels):
+ labels = labels.squeeze(1).long()
+ loss_semantic_seg = self.criterion(mask_pred, labels)
+ loss_semantic_seg *= self.loss_weight
+ return loss_semantic_seg
diff --git a/mmdet/models/roi_heads/mask_heads/global_context_head.py b/mmdet/models/roi_heads/mask_heads/global_context_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8e8cbca95d69e86ec7a2a1e7ed7f158be1b5753
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/global_context_head.py
@@ -0,0 +1,102 @@
+import torch.nn as nn
+from mmcv.cnn import ConvModule
+from mmcv.runner import auto_fp16, force_fp32
+
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
+
+
+@HEADS.register_module()
+class GlobalContextHead(nn.Module):
+    """Global context head used in `SCNet <https://arxiv.org/abs/2012.10150>`_.
+
+ Args:
+        num_convs (int, optional): number of convolutional layers in the head.
+ Default: 4.
+ in_channels (int, optional): number of input channels. Default: 256.
+ conv_out_channels (int, optional): number of output channels before
+ classification layer. Default: 256.
+ num_classes (int, optional): number of classes. Default: 80.
+ loss_weight (float, optional): global context loss weight. Default: 1.
+ conv_cfg (dict, optional): config to init conv layer. Default: None.
+ norm_cfg (dict, optional): config to init norm layer. Default: None.
+ conv_to_res (bool, optional): if True, 2 convs will be grouped into
+ 1 `SimplifiedBasicBlock` using a skip connection. Default: False.
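+
+    Example:
+        >>> # Illustrative shape sketch with the default settings
+        >>> import torch
+        >>> self = GlobalContextHead()
+        >>> feats = [torch.rand(2, 256, s, s) for s in (64, 32, 16, 8, 4)]
+        >>> mc_pred, x = self.forward(feats)
+        >>> assert mc_pred.shape == (2, 80)   # image-level class logits
+        >>> assert x.shape == (2, 256, 1, 1)  # pooled global context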
+ """
+
+ def __init__(self,
+ num_convs=4,
+ in_channels=256,
+ conv_out_channels=256,
+ num_classes=80,
+ loss_weight=1.0,
+ conv_cfg=None,
+ norm_cfg=None,
+ conv_to_res=False):
+ super(GlobalContextHead, self).__init__()
+ self.num_convs = num_convs
+ self.in_channels = in_channels
+ self.conv_out_channels = conv_out_channels
+ self.num_classes = num_classes
+ self.loss_weight = loss_weight
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.conv_to_res = conv_to_res
+ self.fp16_enabled = False
+
+ if self.conv_to_res:
+ num_res_blocks = num_convs // 2
+ self.convs = ResLayer(
+ SimplifiedBasicBlock,
+ in_channels,
+ self.conv_out_channels,
+ num_res_blocks,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ self.num_convs = num_res_blocks
+ else:
+ self.convs = nn.ModuleList()
+ for i in range(self.num_convs):
+ in_channels = self.in_channels if i == 0 else conv_out_channels
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ conv_out_channels,
+ 3,
+ padding=1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg))
+
+ self.pool = nn.AdaptiveAvgPool2d(1)
+ self.fc = nn.Linear(conv_out_channels, num_classes)
+
+ self.criterion = nn.BCEWithLogitsLoss()
+
+ def init_weights(self):
+ """Init weights for the head."""
+ nn.init.normal_(self.fc.weight, 0, 0.01)
+ nn.init.constant_(self.fc.bias, 0)
+
+ @auto_fp16()
+ def forward(self, feats):
+ """Forward function."""
+ x = feats[-1]
+ for i in range(self.num_convs):
+ x = self.convs[i](x)
+ x = self.pool(x)
+
+ # multi-class prediction
+ mc_pred = x.reshape(x.size(0), -1)
+ mc_pred = self.fc(mc_pred)
+
+ return mc_pred, x
+
+ @force_fp32(apply_to=('pred', ))
+ def loss(self, pred, labels):
+ """Loss function."""
+ labels = [lbl.unique() for lbl in labels]
+ targets = pred.new_zeros(pred.size())
+ for i, label in enumerate(labels):
+ targets[i, label] = 1.0
+ loss = self.loss_weight * self.criterion(pred, targets)
+ return loss
diff --git a/mmdet/models/roi_heads/mask_heads/grid_head.py b/mmdet/models/roi_heads/mask_heads/grid_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..83058cbdda934ebfc3a76088e1820848ac01b78b
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/grid_head.py
@@ -0,0 +1,359 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from mmcv.cnn import ConvModule, kaiming_init, normal_init
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class GridHead(nn.Module):
+
+ def __init__(self,
+ grid_points=9,
+ num_convs=8,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_kernel_size=3,
+ point_feat_channels=64,
+ deconv_kernel_size=4,
+ class_agnostic=False,
+ loss_grid=dict(
+ type='CrossEntropyLoss', use_sigmoid=True,
+ loss_weight=15),
+ conv_cfg=None,
+ norm_cfg=dict(type='GN', num_groups=36)):
+ super(GridHead, self).__init__()
+ self.grid_points = grid_points
+ self.num_convs = num_convs
+ self.roi_feat_size = roi_feat_size
+ self.in_channels = in_channels
+ self.conv_kernel_size = conv_kernel_size
+ self.point_feat_channels = point_feat_channels
+ self.conv_out_channels = self.point_feat_channels * self.grid_points
+ self.class_agnostic = class_agnostic
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ if isinstance(norm_cfg, dict) and norm_cfg['type'] == 'GN':
+ assert self.conv_out_channels % norm_cfg['num_groups'] == 0
+
+ assert self.grid_points >= 4
+ self.grid_size = int(np.sqrt(self.grid_points))
+ if self.grid_size * self.grid_size != self.grid_points:
+ raise ValueError('grid_points must be a square number')
+
+ # the predicted heatmap is half of whole_map_size
+ if not isinstance(self.roi_feat_size, int):
+            raise ValueError('Only square RoIs are supported in Grid R-CNN')
+ self.whole_map_size = self.roi_feat_size * 4
+
+ # compute point-wise sub-regions
+ self.sub_regions = self.calc_sub_regions()
+
+ self.convs = []
+ for i in range(self.num_convs):
+ in_channels = (
+ self.in_channels if i == 0 else self.conv_out_channels)
+ stride = 2 if i == 0 else 1
+ padding = (self.conv_kernel_size - 1) // 2
+ self.convs.append(
+ ConvModule(
+ in_channels,
+ self.conv_out_channels,
+ self.conv_kernel_size,
+ stride=stride,
+ padding=padding,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg,
+ bias=True))
+ self.convs = nn.Sequential(*self.convs)
+
+ self.deconv1 = nn.ConvTranspose2d(
+ self.conv_out_channels,
+ self.conv_out_channels,
+ kernel_size=deconv_kernel_size,
+ stride=2,
+ padding=(deconv_kernel_size - 2) // 2,
+ groups=grid_points)
+ self.norm1 = nn.GroupNorm(grid_points, self.conv_out_channels)
+ self.deconv2 = nn.ConvTranspose2d(
+ self.conv_out_channels,
+ grid_points,
+ kernel_size=deconv_kernel_size,
+ stride=2,
+ padding=(deconv_kernel_size - 2) // 2,
+ groups=grid_points)
+
+ # find the 4-neighbor of each grid point
+ self.neighbor_points = []
+ grid_size = self.grid_size
+ for i in range(grid_size): # i-th column
+ for j in range(grid_size): # j-th row
+ neighbors = []
+ if i > 0: # left: (i - 1, j)
+ neighbors.append((i - 1) * grid_size + j)
+ if j > 0: # up: (i, j - 1)
+ neighbors.append(i * grid_size + j - 1)
+ if j < grid_size - 1: # down: (i, j + 1)
+ neighbors.append(i * grid_size + j + 1)
+ if i < grid_size - 1: # right: (i + 1, j)
+ neighbors.append((i + 1) * grid_size + j)
+ self.neighbor_points.append(tuple(neighbors))
+ # total edges in the grid
+ self.num_edges = sum([len(p) for p in self.neighbor_points])
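+        # e.g. with the default 3x3 grid: 4 corners x 2 neighbours
+        # + 4 edge midpoints x 3 + 1 centre x 4 = 24 edges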
+
+ self.forder_trans = nn.ModuleList() # first-order feature transition
+ self.sorder_trans = nn.ModuleList() # second-order feature transition
+ for neighbors in self.neighbor_points:
+ fo_trans = nn.ModuleList()
+ so_trans = nn.ModuleList()
+ for _ in range(len(neighbors)):
+ # each transition module consists of a 5x5 depth-wise conv and
+ # 1x1 conv.
+ fo_trans.append(
+ nn.Sequential(
+ nn.Conv2d(
+ self.point_feat_channels,
+ self.point_feat_channels,
+ 5,
+ stride=1,
+ padding=2,
+ groups=self.point_feat_channels),
+ nn.Conv2d(self.point_feat_channels,
+ self.point_feat_channels, 1)))
+ so_trans.append(
+ nn.Sequential(
+ nn.Conv2d(
+ self.point_feat_channels,
+ self.point_feat_channels,
+ 5,
+ 1,
+ 2,
+ groups=self.point_feat_channels),
+ nn.Conv2d(self.point_feat_channels,
+ self.point_feat_channels, 1)))
+ self.forder_trans.append(fo_trans)
+ self.sorder_trans.append(so_trans)
+
+ self.loss_grid = build_loss(loss_grid)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
+ # TODO: compare mode = "fan_in" or "fan_out"
+ kaiming_init(m)
+ for m in self.modules():
+ if isinstance(m, nn.ConvTranspose2d):
+ normal_init(m, std=0.001)
+ nn.init.constant_(self.deconv2.bias, -np.log(0.99 / 0.01))
+
+ def forward(self, x):
+ assert x.shape[-1] == x.shape[-2] == self.roi_feat_size
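+        # Overall pipeline: RoI features -> stacked convs (2x downsample) ->
+        # first/second-order fusion with neighbouring grid points -> two
+        # deconvs producing one half-sized heatmap per grid point.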
+ # RoI feature transformation, downsample 2x
+ x = self.convs(x)
+
+ c = self.point_feat_channels
+ # first-order fusion
+ x_fo = [None for _ in range(self.grid_points)]
+ for i, points in enumerate(self.neighbor_points):
+ x_fo[i] = x[:, i * c:(i + 1) * c]
+ for j, point_idx in enumerate(points):
+ x_fo[i] = x_fo[i] + self.forder_trans[i][j](
+ x[:, point_idx * c:(point_idx + 1) * c])
+
+ # second-order fusion
+ x_so = [None for _ in range(self.grid_points)]
+ for i, points in enumerate(self.neighbor_points):
+ x_so[i] = x[:, i * c:(i + 1) * c]
+ for j, point_idx in enumerate(points):
+ x_so[i] = x_so[i] + self.sorder_trans[i][j](x_fo[point_idx])
+
+ # predicted heatmap with fused features
+ x2 = torch.cat(x_so, dim=1)
+ x2 = self.deconv1(x2)
+ x2 = F.relu(self.norm1(x2), inplace=True)
+ heatmap = self.deconv2(x2)
+
+ # predicted heatmap with original features (applicable during training)
+ if self.training:
+ x1 = x
+ x1 = self.deconv1(x1)
+ x1 = F.relu(self.norm1(x1), inplace=True)
+ heatmap_unfused = self.deconv2(x1)
+ else:
+ heatmap_unfused = heatmap
+
+ return dict(fused=heatmap, unfused=heatmap_unfused)
+
+ def calc_sub_regions(self):
+ """Compute point specific representation regions.
+
+ See Grid R-CNN Plus (https://arxiv.org/abs/1906.05688) for details.
+ """
+ # to make it consistent with the original implementation, half_size
+ # is computed as 2 * quarter_size, which is smaller
+ half_size = self.whole_map_size // 4 * 2
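+        # e.g. with roi_feat_size=14: whole_map_size = 56, half_size = 28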
+ sub_regions = []
+ for i in range(self.grid_points):
+ x_idx = i // self.grid_size
+ y_idx = i % self.grid_size
+ if x_idx == 0:
+ sub_x1 = 0
+ elif x_idx == self.grid_size - 1:
+ sub_x1 = half_size
+ else:
+ ratio = x_idx / (self.grid_size - 1) - 0.25
+ sub_x1 = max(int(ratio * self.whole_map_size), 0)
+
+ if y_idx == 0:
+ sub_y1 = 0
+ elif y_idx == self.grid_size - 1:
+ sub_y1 = half_size
+ else:
+ ratio = y_idx / (self.grid_size - 1) - 0.25
+ sub_y1 = max(int(ratio * self.whole_map_size), 0)
+ sub_regions.append(
+ (sub_x1, sub_y1, sub_x1 + half_size, sub_y1 + half_size))
+ return sub_regions
+
+ def get_targets(self, sampling_results, rcnn_train_cfg):
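+        # Build the supervision heatmaps: every positive RoI is expanded to
+        # 2x its size, and for each grid point a small disc of radius
+        # `pos_radius` around its location is marked as positive.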
+ # mix all samples (across images) together.
+ pos_bboxes = torch.cat([res.pos_bboxes for res in sampling_results],
+ dim=0).cpu()
+ pos_gt_bboxes = torch.cat(
+ [res.pos_gt_bboxes for res in sampling_results], dim=0).cpu()
+ assert pos_bboxes.shape == pos_gt_bboxes.shape
+
+ # expand pos_bboxes to 2x of original size
+ x1 = pos_bboxes[:, 0] - (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
+ y1 = pos_bboxes[:, 1] - (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
+ x2 = pos_bboxes[:, 2] + (pos_bboxes[:, 2] - pos_bboxes[:, 0]) / 2
+ y2 = pos_bboxes[:, 3] + (pos_bboxes[:, 3] - pos_bboxes[:, 1]) / 2
+ pos_bboxes = torch.stack([x1, y1, x2, y2], dim=-1)
+ pos_bbox_ws = (pos_bboxes[:, 2] - pos_bboxes[:, 0]).unsqueeze(-1)
+ pos_bbox_hs = (pos_bboxes[:, 3] - pos_bboxes[:, 1]).unsqueeze(-1)
+
+ num_rois = pos_bboxes.shape[0]
+ map_size = self.whole_map_size
+ # this is not the final target shape
+ targets = torch.zeros((num_rois, self.grid_points, map_size, map_size),
+ dtype=torch.float)
+
+ # pre-compute interpolation factors for all grid points.
+ # the first item is the factor of x-dim, and the second is y-dim.
+ # for a 9-point grid, factors are like (1, 0), (0.5, 0.5), (0, 1)
+ factors = []
+ for j in range(self.grid_points):
+ x_idx = j // self.grid_size
+ y_idx = j % self.grid_size
+ factors.append((1 - x_idx / (self.grid_size - 1),
+ 1 - y_idx / (self.grid_size - 1)))
+
+ radius = rcnn_train_cfg.pos_radius
+ radius2 = radius**2
+ for i in range(num_rois):
+ # ignore small bboxes
+ if (pos_bbox_ws[i] <= self.grid_size
+ or pos_bbox_hs[i] <= self.grid_size):
+ continue
+ # for each grid point, mark a small circle as positive
+ for j in range(self.grid_points):
+ factor_x, factor_y = factors[j]
+ gridpoint_x = factor_x * pos_gt_bboxes[i, 0] + (
+ 1 - factor_x) * pos_gt_bboxes[i, 2]
+ gridpoint_y = factor_y * pos_gt_bboxes[i, 1] + (
+ 1 - factor_y) * pos_gt_bboxes[i, 3]
+
+ cx = int((gridpoint_x - pos_bboxes[i, 0]) / pos_bbox_ws[i] *
+ map_size)
+ cy = int((gridpoint_y - pos_bboxes[i, 1]) / pos_bbox_hs[i] *
+ map_size)
+
+ for x in range(cx - radius, cx + radius + 1):
+ for y in range(cy - radius, cy + radius + 1):
+ if x >= 0 and x < map_size and y >= 0 and y < map_size:
+ if (x - cx)**2 + (y - cy)**2 <= radius2:
+ targets[i, j, y, x] = 1
+        # reduce the target heatmap size by half, as
+        # proposed in Grid R-CNN Plus (https://arxiv.org/abs/1906.05688).
+ sub_targets = []
+ for i in range(self.grid_points):
+ sub_x1, sub_y1, sub_x2, sub_y2 = self.sub_regions[i]
+ sub_targets.append(targets[:, [i], sub_y1:sub_y2, sub_x1:sub_x2])
+ sub_targets = torch.cat(sub_targets, dim=1)
+ sub_targets = sub_targets.to(sampling_results[0].pos_bboxes.device)
+ return sub_targets
+
+ def loss(self, grid_pred, grid_targets):
+ loss_fused = self.loss_grid(grid_pred['fused'], grid_targets)
+ loss_unfused = self.loss_grid(grid_pred['unfused'], grid_targets)
+ loss_grid = loss_fused + loss_unfused
+ return dict(loss_grid=loss_grid)
+
+ def get_bboxes(self, det_bboxes, grid_pred, img_metas):
+ # TODO: refactoring
+ assert det_bboxes.shape[0] == grid_pred.shape[0]
+ det_bboxes = det_bboxes.cpu()
+ cls_scores = det_bboxes[:, [4]]
+ det_bboxes = det_bboxes[:, :4]
+ grid_pred = grid_pred.sigmoid().cpu()
+
+ R, c, h, w = grid_pred.shape
+ half_size = self.whole_map_size // 4 * 2
+ assert h == w == half_size
+ assert c == self.grid_points
+
+ # find the point with max scores in the half-sized heatmap
+ grid_pred = grid_pred.view(R * c, h * w)
+ pred_scores, pred_position = grid_pred.max(dim=1)
+ xs = pred_position % w
+ ys = pred_position // w
+
+ # get the position in the whole heatmap instead of half-sized heatmap
+ for i in range(self.grid_points):
+ xs[i::self.grid_points] += self.sub_regions[i][0]
+ ys[i::self.grid_points] += self.sub_regions[i][1]
+
+ # reshape to (num_rois, grid_points)
+ pred_scores, xs, ys = tuple(
+ map(lambda x: x.view(R, c), [pred_scores, xs, ys]))
+
+ # get expanded pos_bboxes
+ widths = (det_bboxes[:, 2] - det_bboxes[:, 0]).unsqueeze(-1)
+ heights = (det_bboxes[:, 3] - det_bboxes[:, 1]).unsqueeze(-1)
+ x1 = (det_bboxes[:, 0, None] - widths / 2)
+ y1 = (det_bboxes[:, 1, None] - heights / 2)
+ # map the grid point to the absolute coordinates
+ abs_xs = (xs.float() + 0.5) / w * widths + x1
+ abs_ys = (ys.float() + 0.5) / h * heights + y1
+
+        # get the grid point indices that fall on the bbox boundaries
+ x1_inds = [i for i in range(self.grid_size)]
+ y1_inds = [i * self.grid_size for i in range(self.grid_size)]
+ x2_inds = [
+ self.grid_points - self.grid_size + i
+ for i in range(self.grid_size)
+ ]
+ y2_inds = [(i + 1) * self.grid_size - 1 for i in range(self.grid_size)]
+
+        # voting of all grid points on each boundary
+ bboxes_x1 = (abs_xs[:, x1_inds] * pred_scores[:, x1_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, x1_inds].sum(dim=1, keepdim=True))
+ bboxes_y1 = (abs_ys[:, y1_inds] * pred_scores[:, y1_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, y1_inds].sum(dim=1, keepdim=True))
+ bboxes_x2 = (abs_xs[:, x2_inds] * pred_scores[:, x2_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, x2_inds].sum(dim=1, keepdim=True))
+ bboxes_y2 = (abs_ys[:, y2_inds] * pred_scores[:, y2_inds]).sum(
+ dim=1, keepdim=True) / (
+ pred_scores[:, y2_inds].sum(dim=1, keepdim=True))
+
+ bbox_res = torch.cat(
+ [bboxes_x1, bboxes_y1, bboxes_x2, bboxes_y2, cls_scores], dim=1)
+ bbox_res[:, [0, 2]].clamp_(min=0, max=img_metas[0]['img_shape'][1])
+ bbox_res[:, [1, 3]].clamp_(min=0, max=img_metas[0]['img_shape'][0])
+
+ return bbox_res
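+
+    # Illustrative sketch (not taken from the original code): for a 3x3 grid
+    # the boundary index lists above are x1_inds = [0, 1, 2],
+    # y1_inds = [0, 3, 6], x2_inds = [6, 7, 8] and y2_inds = [2, 5, 8], so the
+    # final left edge, for instance, is the confidence-weighted average of the
+    # three grid points lying on that edge:
+    #   x1 = (abs_xs[:, [0, 1, 2]] * pred_scores[:, [0, 1, 2]]).sum(1) \
+    #        / pred_scores[:, [0, 1, 2]].sum(1)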
diff --git a/mmdet/models/roi_heads/mask_heads/htc_mask_head.py b/mmdet/models/roi_heads/mask_heads/htc_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..330b778ebad8d48d55d09ddd42baa70ec10ae463
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/htc_mask_head.py
@@ -0,0 +1,43 @@
+from mmcv.cnn import ConvModule
+
+from mmdet.models.builder import HEADS
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class HTCMaskHead(FCNMaskHead):
+
+ def __init__(self, with_conv_res=True, *args, **kwargs):
+ super(HTCMaskHead, self).__init__(*args, **kwargs)
+ self.with_conv_res = with_conv_res
+ if self.with_conv_res:
+ self.conv_res = ConvModule(
+ self.conv_out_channels,
+ self.conv_out_channels,
+ 1,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+
+ def init_weights(self):
+ super(HTCMaskHead, self).init_weights()
+ if self.with_conv_res:
+ self.conv_res.init_weights()
+
+ def forward(self, x, res_feat=None, return_logits=True, return_feat=True):
+ if res_feat is not None:
+ assert self.with_conv_res
+ res_feat = self.conv_res(res_feat)
+ x = x + res_feat
+ for conv in self.convs:
+ x = conv(x)
+ res_feat = x
+ outs = []
+ if return_logits:
+ x = self.upsample(x)
+ if self.upsample_method == 'deconv':
+ x = self.relu(x)
+ mask_pred = self.conv_logits(x)
+ outs.append(mask_pred)
+ if return_feat:
+ outs.append(res_feat)
+ return outs if len(outs) > 1 else outs[0]
diff --git a/mmdet/models/roi_heads/mask_heads/mask_point_head.py b/mmdet/models/roi_heads/mask_heads/mask_point_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb92903a9488a44b984a489a354d838cc88f8ad4
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/mask_point_head.py
@@ -0,0 +1,300 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend/point_head/point_head.py # noqa
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import ConvModule, normal_init
+from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class MaskPointHead(nn.Module):
+    """A mask point head used in PointRend.
+
+    ``MaskPointHead`` uses a shared multi-layer perceptron (equivalent to
+    nn.Conv1d with kernel size 1) to predict the logits of input points. The
+    fine-grained features and coarse features are concatenated for prediction.
+
+ Args:
+ num_fcs (int): Number of fc layers in the head. Default: 3.
+ in_channels (int): Number of input channels. Default: 256.
+ fc_channels (int): Number of fc channels. Default: 256.
+ num_classes (int): Number of classes for logits. Default: 80.
+        class_agnostic (bool): Whether to use class-agnostic classification.
+            If so, the output channels of logits will be 1. Default: False.
+        coarse_pred_each_layer (bool): Whether to concatenate the coarse
+            feature with the output of each fc layer. Default: True.
+ conv_cfg (dict | None): Dictionary to construct and config conv layer.
+ Default: dict(type='Conv1d'))
+ norm_cfg (dict | None): Dictionary to construct and config norm layer.
+ Default: None.
+ loss_point (dict): Dictionary to construct and config loss layer of
+ point head. Default: dict(type='CrossEntropyLoss', use_mask=True,
+ loss_weight=1.0).
+ """
+
+ def __init__(self,
+ num_classes,
+ num_fcs=3,
+ in_channels=256,
+ fc_channels=256,
+ class_agnostic=False,
+ coarse_pred_each_layer=True,
+ conv_cfg=dict(type='Conv1d'),
+ norm_cfg=None,
+ act_cfg=dict(type='ReLU'),
+ loss_point=dict(
+ type='CrossEntropyLoss', use_mask=True, loss_weight=1.0)):
+ super().__init__()
+ self.num_fcs = num_fcs
+ self.in_channels = in_channels
+ self.fc_channels = fc_channels
+ self.num_classes = num_classes
+ self.class_agnostic = class_agnostic
+ self.coarse_pred_each_layer = coarse_pred_each_layer
+ self.conv_cfg = conv_cfg
+ self.norm_cfg = norm_cfg
+ self.loss_point = build_loss(loss_point)
+
+ fc_in_channels = in_channels + num_classes
+ self.fcs = nn.ModuleList()
+ for _ in range(num_fcs):
+ fc = ConvModule(
+ fc_in_channels,
+ fc_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg)
+ self.fcs.append(fc)
+ fc_in_channels = fc_channels
+ fc_in_channels += num_classes if self.coarse_pred_each_layer else 0
+
+ out_channels = 1 if self.class_agnostic else self.num_classes
+ self.fc_logits = nn.Conv1d(
+ fc_in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def init_weights(self):
+        """Initialize the last classification layer of MaskPointHead; the conv
+        layers are already initialized by ConvModule."""
+ normal_init(self.fc_logits, std=0.001)
+
+ def forward(self, fine_grained_feats, coarse_feats):
+        """Classify each point based on fine-grained and coarse feats.
+
+ Args:
+ fine_grained_feats (Tensor): Fine grained feature sampled from FPN,
+ shape (num_rois, in_channels, num_points).
+ coarse_feats (Tensor): Coarse feature sampled from CoarseMaskHead,
+ shape (num_rois, num_classes, num_points).
+
+ Returns:
+ Tensor: Point classification results,
+                shape (num_rois, num_classes, num_points).
+ """
+
+ x = torch.cat([fine_grained_feats, coarse_feats], dim=1)
+ for fc in self.fcs:
+ x = fc(x)
+ if self.coarse_pred_each_layer:
+ x = torch.cat((x, coarse_feats), dim=1)
+ return self.fc_logits(x)
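+
+    # Shape walk-through (illustrative, assuming the default arguments): with
+    # fine_grained_feats of shape (R, 256, P) and coarse_feats of shape
+    # (R, num_classes, P), the concatenated input is (R, 256 + num_classes, P).
+    # Each fc is a 1x1 Conv1d, i.e. a per-point fully-connected layer shared
+    # across the P points; with coarse_pred_each_layer=True the coarse logits
+    # are re-appended after every fc, and fc_logits finally maps to
+    # (R, num_classes, P) (or (R, 1, P) when class_agnostic=True).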
+
+ def get_targets(self, rois, rel_roi_points, sampling_results, gt_masks,
+ cfg):
+ """Get training targets of MaskPointHead for all images.
+
+ Args:
+ rois (Tensor): Region of Interest, shape (num_rois, 5).
+            rel_roi_points (Tensor): Point coordinates relative to RoI, shape
+ (num_rois, num_points, 2).
+ sampling_results (:obj:`SamplingResult`): Sampling result after
+ sampling and assignment.
+            gt_masks (Tensor): Ground truth segmentation masks of
+ corresponding boxes, shape (num_rois, height, width).
+ cfg (dict): Training cfg.
+
+ Returns:
+ Tensor: Point target, shape (num_rois, num_points).
+ """
+
+ num_imgs = len(sampling_results)
+ rois_list = []
+ rel_roi_points_list = []
+ for batch_ind in range(num_imgs):
+ inds = (rois[:, 0] == batch_ind)
+ rois_list.append(rois[inds])
+ rel_roi_points_list.append(rel_roi_points[inds])
+ pos_assigned_gt_inds_list = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+ cfg_list = [cfg for _ in range(num_imgs)]
+
+ point_targets = map(self._get_target_single, rois_list,
+ rel_roi_points_list, pos_assigned_gt_inds_list,
+ gt_masks, cfg_list)
+ point_targets = list(point_targets)
+
+ if len(point_targets) > 0:
+ point_targets = torch.cat(point_targets)
+
+ return point_targets
+
+ def _get_target_single(self, rois, rel_roi_points, pos_assigned_gt_inds,
+ gt_masks, cfg):
+ """Get training target of MaskPointHead for each image."""
+ num_pos = rois.size(0)
+ num_points = cfg.num_points
+ if num_pos > 0:
+ gt_masks_th = (
+ gt_masks.to_tensor(rois.dtype, rois.device).index_select(
+ 0, pos_assigned_gt_inds))
+ gt_masks_th = gt_masks_th.unsqueeze(1)
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois, rel_roi_points, gt_masks_th.shape[2:])
+ point_targets = point_sample(gt_masks_th,
+ rel_img_points).squeeze(1)
+ else:
+ point_targets = rois.new_zeros((0, num_points))
+ return point_targets
+
+ def loss(self, point_pred, point_targets, labels):
+ """Calculate loss for MaskPointHead.
+
+ Args:
+            point_pred (Tensor): Point prediction result, shape
+ (num_rois, num_classes, num_points).
+ point_targets (Tensor): Point targets, shape (num_roi, num_points).
+ labels (Tensor): Class label of corresponding boxes,
+ shape (num_rois, )
+
+ Returns:
+ dict[str, Tensor]: a dictionary of point loss components
+ """
+
+ loss = dict()
+ if self.class_agnostic:
+ loss_point = self.loss_point(point_pred, point_targets,
+ torch.zeros_like(labels))
+ else:
+ loss_point = self.loss_point(point_pred, point_targets, labels)
+ loss['loss_point'] = loss_point
+ return loss
+
+ def _get_uncertainty(self, mask_pred, labels):
+ """Estimate uncertainty based on pred logits.
+
+        We estimate uncertainty as the L1 distance between 0.0 and the logit
+        prediction in 'mask_pred' for the foreground class in `labels`.
+
+ Args:
+            mask_pred (Tensor): mask prediction logits, shape (num_rois,
+ num_classes, mask_height, mask_width).
+
+ labels (list[Tensor]): Either predicted or ground truth label for
+ each predicted mask, of length num_rois.
+
+ Returns:
+ scores (Tensor): Uncertainty scores with the most uncertain
+ locations having the highest uncertainty score,
+ shape (num_rois, 1, mask_height, mask_width)
+ """
+ if mask_pred.shape[1] == 1:
+ gt_class_logits = mask_pred.clone()
+ else:
+ inds = torch.arange(mask_pred.shape[0], device=mask_pred.device)
+ gt_class_logits = mask_pred[inds, labels].unsqueeze(1)
+ return -torch.abs(gt_class_logits)
+
+ def get_roi_rel_points_train(self, mask_pred, labels, cfg):
+ """Get ``num_points`` most uncertain points with random points during
+ train.
+
+ Sample points in [0, 1] x [0, 1] coordinate space based on their
+        uncertainty. The uncertainties are calculated for each point using the
+        '_get_uncertainty()' function, which takes the point's logit prediction
+        as input.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+ labels (list): The ground truth class for each instance.
+ cfg (dict): Training config of point head.
+
+ Returns:
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+                that contains the coordinates of the sampled points.
+ """
+ num_points = cfg.num_points
+ oversample_ratio = cfg.oversample_ratio
+ importance_sample_ratio = cfg.importance_sample_ratio
+ assert oversample_ratio >= 1
+ assert 0 <= importance_sample_ratio <= 1
+ batch_size = mask_pred.shape[0]
+ num_sampled = int(num_points * oversample_ratio)
+ point_coords = torch.rand(
+ batch_size, num_sampled, 2, device=mask_pred.device)
+ point_logits = point_sample(mask_pred, point_coords)
+ # It is crucial to calculate uncertainty based on the sampled
+ # prediction value for the points. Calculating uncertainties of the
+ # coarse predictions first and sampling them for points leads to
+ # incorrect results. To illustrate this: assume uncertainty func(
+ # logits)=-abs(logits), a sampled point between two coarse
+ # predictions with -1 and 1 logits has 0 logits, and therefore 0
+ # uncertainty value. However, if we calculate uncertainties for the
+ # coarse predictions first, both will have -1 uncertainty,
+        # and the sampled point will get -1 uncertainty.
+ point_uncertainties = self._get_uncertainty(point_logits, labels)
+ num_uncertain_points = int(importance_sample_ratio * num_points)
+ num_random_points = num_points - num_uncertain_points
+ idx = torch.topk(
+ point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
+ shift = num_sampled * torch.arange(
+ batch_size, dtype=torch.long, device=mask_pred.device)
+ idx += shift[:, None]
+ point_coords = point_coords.view(-1, 2)[idx.view(-1), :].view(
+ batch_size, num_uncertain_points, 2)
+ if num_random_points > 0:
+ rand_roi_coords = torch.rand(
+ batch_size, num_random_points, 2, device=mask_pred.device)
+ point_coords = torch.cat((point_coords, rand_roi_coords), dim=1)
+ return point_coords
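+
+    # Numeric sketch (illustrative values, not taken from a config in this
+    # repo): with cfg.num_points = 196, cfg.oversample_ratio = 3 and
+    # cfg.importance_sample_ratio = 0.75, the head first samples
+    # 196 * 3 = 588 random candidate points, keeps the 147 most uncertain of
+    # them (0.75 * 196), and adds 49 freshly sampled random points, so each
+    # RoI ends up with exactly 196 training points.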
+
+ def get_roi_rel_points_test(self, mask_pred, pred_label, cfg):
+ """Get ``num_points`` most uncertain points during test.
+
+ Args:
+ mask_pred (Tensor): A tensor of shape (num_rois, num_classes,
+ mask_height, mask_width) for class-specific or class-agnostic
+ prediction.
+            pred_label (list): The predicted class for each instance.
+ cfg (dict): Testing config of point head.
+
+ Returns:
+ point_indices (Tensor): A tensor of shape (num_rois, num_points)
+ that contains indices from [0, mask_height x mask_width) of the
+ most uncertain points.
+ point_coords (Tensor): A tensor of shape (num_rois, num_points, 2)
+ that contains [0, 1] x [0, 1] normalized coordinates of the
+                most uncertain points from the [mask_height, mask_width] grid.
+ """
+ num_points = cfg.subdivision_num_points
+ uncertainty_map = self._get_uncertainty(mask_pred, pred_label)
+ num_rois, _, mask_height, mask_width = uncertainty_map.shape
+ h_step = 1.0 / mask_height
+ w_step = 1.0 / mask_width
+
+ uncertainty_map = uncertainty_map.view(num_rois,
+ mask_height * mask_width)
+ num_points = min(mask_height * mask_width, num_points)
+ point_indices = uncertainty_map.topk(num_points, dim=1)[1]
+ point_coords = uncertainty_map.new_zeros(num_rois, num_points, 2)
+ point_coords[:, :, 0] = w_step / 2.0 + (point_indices %
+ mask_width).float() * w_step
+ point_coords[:, :, 1] = h_step / 2.0 + (point_indices //
+ mask_width).float() * h_step
+ return point_indices, point_coords
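+
+    # Illustrative sketch (not part of the original code): point_indices are
+    # flat indices into the (mask_height x mask_width) grid, converted to
+    # normalized cell-center coordinates. For example, with a 4x4 grid
+    # (h_step = w_step = 0.25), index 6 maps to column 6 % 4 = 2 and
+    # row 6 // 4 = 1, i.e.
+    #   (x, y) = (0.125 + 2 * 0.25, 0.125 + 1 * 0.25) = (0.625, 0.375).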
diff --git a/mmdet/models/roi_heads/mask_heads/maskiou_head.py b/mmdet/models/roi_heads/mask_heads/maskiou_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..39bcd6a7dbdb089cd19cef811038e0b6a80ab89a
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/maskiou_head.py
@@ -0,0 +1,186 @@
+import numpy as np
+import torch
+import torch.nn as nn
+from mmcv.cnn import Conv2d, Linear, MaxPool2d, kaiming_init, normal_init
+from mmcv.runner import force_fp32
+from torch.nn.modules.utils import _pair
+
+from mmdet.models.builder import HEADS, build_loss
+
+
+@HEADS.register_module()
+class MaskIoUHead(nn.Module):
+ """Mask IoU Head.
+
+ This head predicts the IoU of predicted masks and corresponding gt masks.
+ """
+
+ def __init__(self,
+ num_convs=4,
+ num_fcs=2,
+ roi_feat_size=14,
+ in_channels=256,
+ conv_out_channels=256,
+ fc_out_channels=1024,
+ num_classes=80,
+ loss_iou=dict(type='MSELoss', loss_weight=0.5)):
+ super(MaskIoUHead, self).__init__()
+ self.in_channels = in_channels
+ self.conv_out_channels = conv_out_channels
+ self.fc_out_channels = fc_out_channels
+ self.num_classes = num_classes
+ self.fp16_enabled = False
+
+ self.convs = nn.ModuleList()
+ for i in range(num_convs):
+ if i == 0:
+ # concatenation of mask feature and mask prediction
+ in_channels = self.in_channels + 1
+ else:
+ in_channels = self.conv_out_channels
+ stride = 2 if i == num_convs - 1 else 1
+ self.convs.append(
+ Conv2d(
+ in_channels,
+ self.conv_out_channels,
+ 3,
+ stride=stride,
+ padding=1))
+
+ roi_feat_size = _pair(roi_feat_size)
+ pooled_area = (roi_feat_size[0] // 2) * (roi_feat_size[1] // 2)
+ self.fcs = nn.ModuleList()
+ for i in range(num_fcs):
+ in_channels = (
+ self.conv_out_channels *
+ pooled_area if i == 0 else self.fc_out_channels)
+ self.fcs.append(Linear(in_channels, self.fc_out_channels))
+
+ self.fc_mask_iou = Linear(self.fc_out_channels, self.num_classes)
+ self.relu = nn.ReLU()
+ self.max_pool = MaxPool2d(2, 2)
+ self.loss_iou = build_loss(loss_iou)
+
+ def init_weights(self):
+ for conv in self.convs:
+ kaiming_init(conv)
+ for fc in self.fcs:
+ kaiming_init(
+ fc,
+ a=1,
+ mode='fan_in',
+ nonlinearity='leaky_relu',
+ distribution='uniform')
+ normal_init(self.fc_mask_iou, std=0.01)
+
+ def forward(self, mask_feat, mask_pred):
+ mask_pred = mask_pred.sigmoid()
+ mask_pred_pooled = self.max_pool(mask_pred.unsqueeze(1))
+
+ x = torch.cat((mask_feat, mask_pred_pooled), 1)
+
+ for conv in self.convs:
+ x = self.relu(conv(x))
+ x = x.flatten(1)
+ for fc in self.fcs:
+ x = self.relu(fc(x))
+ mask_iou = self.fc_mask_iou(x)
+ return mask_iou
+
+ @force_fp32(apply_to=('mask_iou_pred', ))
+ def loss(self, mask_iou_pred, mask_iou_targets):
+ pos_inds = mask_iou_targets > 0
+ if pos_inds.sum() > 0:
+ loss_mask_iou = self.loss_iou(mask_iou_pred[pos_inds],
+ mask_iou_targets[pos_inds])
+ else:
+ loss_mask_iou = mask_iou_pred.sum() * 0
+ return dict(loss_mask_iou=loss_mask_iou)
+
+ @force_fp32(apply_to=('mask_pred', ))
+ def get_targets(self, sampling_results, gt_masks, mask_pred, mask_targets,
+ rcnn_train_cfg):
+ """Compute target of mask IoU.
+
+        Mask IoU target is the IoU between the predicted mask (inside a bbox)
+        and the corresponding gt mask (the whole instance).
+        The intersection area is computed inside the bbox, and the gt mask area
+        is computed in two steps: first compute the gt area inside the bbox,
+        then divide it by the ratio between the gt area inside the bbox and
+        the gt area of the whole instance.
+
+ Args:
+ sampling_results (list[:obj:`SamplingResult`]): sampling results.
+            gt_masks (BitmapMasks | PolygonMasks): Gt masks (the whole
+                instance) of each image, with the same shape as the input
+                image.
+ mask_pred (Tensor): Predicted masks of each positive proposal,
+ shape (num_pos, h, w).
+ mask_targets (Tensor): Gt mask of each positive proposal,
+ binary map of the shape (num_pos, h, w).
+ rcnn_train_cfg (dict): Training config for R-CNN part.
+
+ Returns:
+ Tensor: mask iou target (length == num positive).
+ """
+ pos_proposals = [res.pos_bboxes for res in sampling_results]
+ pos_assigned_gt_inds = [
+ res.pos_assigned_gt_inds for res in sampling_results
+ ]
+
+ # compute the area ratio of gt areas inside the proposals and
+ # the whole instance
+ area_ratios = map(self._get_area_ratio, pos_proposals,
+ pos_assigned_gt_inds, gt_masks)
+ area_ratios = torch.cat(list(area_ratios))
+ assert mask_targets.size(0) == area_ratios.size(0)
+
+ mask_pred = (mask_pred > rcnn_train_cfg.mask_thr_binary).float()
+ mask_pred_areas = mask_pred.sum((-1, -2))
+
+ # mask_pred and mask_targets are binary maps
+ overlap_areas = (mask_pred * mask_targets).sum((-1, -2))
+
+ # compute the mask area of the whole instance
+ gt_full_areas = mask_targets.sum((-1, -2)) / (area_ratios + 1e-7)
+
+ mask_iou_targets = overlap_areas / (
+ mask_pred_areas + gt_full_areas - overlap_areas)
+ return mask_iou_targets
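+
+    # Worked example (illustrative numbers): suppose 60% of a gt instance lies
+    # inside the proposal (area_ratio = 0.6) and, at mask resolution, the
+    # binary gt crop covers 300 pixels, the thresholded prediction covers 280
+    # pixels, and the two overlap on 240 pixels. The full gt area is then
+    # estimated as 300 / 0.6 = 500 and the target IoU is
+    #   240 / (280 + 500 - 240) = 240 / 540, i.e. about 0.44.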
+
+ def _get_area_ratio(self, pos_proposals, pos_assigned_gt_inds, gt_masks):
+ """Compute area ratio of the gt mask inside the proposal and the gt
+ mask of the corresponding instance."""
+ num_pos = pos_proposals.size(0)
+ if num_pos > 0:
+ area_ratios = []
+ proposals_np = pos_proposals.cpu().numpy()
+ pos_assigned_gt_inds = pos_assigned_gt_inds.cpu().numpy()
+ # compute mask areas of gt instances (batch processing for speedup)
+ gt_instance_mask_area = gt_masks.areas
+ for i in range(num_pos):
+ gt_mask = gt_masks[pos_assigned_gt_inds[i]]
+
+ # crop the gt mask inside the proposal
+ bbox = proposals_np[i, :].astype(np.int32)
+ gt_mask_in_proposal = gt_mask.crop(bbox)
+
+ ratio = gt_mask_in_proposal.areas[0] / (
+ gt_instance_mask_area[pos_assigned_gt_inds[i]] + 1e-7)
+ area_ratios.append(ratio)
+ area_ratios = torch.from_numpy(np.stack(area_ratios)).float().to(
+ pos_proposals.device)
+ else:
+ area_ratios = pos_proposals.new_zeros((0, ))
+ return area_ratios
+
+ @force_fp32(apply_to=('mask_iou_pred', ))
+ def get_mask_scores(self, mask_iou_pred, det_bboxes, det_labels):
+ """Get the mask scores.
+
+ mask_score = bbox_score * mask_iou
+ """
+ inds = range(det_labels.size(0))
+ mask_scores = mask_iou_pred[inds, det_labels] * det_bboxes[inds, -1]
+ mask_scores = mask_scores.cpu().numpy()
+ det_labels = det_labels.cpu().numpy()
+ return [mask_scores[det_labels == i] for i in range(self.num_classes)]
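+
+    # Illustrative example (not part of the original code): a detection with
+    # bbox score 0.9 whose predicted mask IoU for its class is 0.7 receives a
+    # mask score of 0.9 * 0.7 = 0.63; scores are then grouped per class to
+    # match the per-class segmentation result layout used by the RoI head.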
diff --git a/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py b/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..983a2d9db71a3b2b4980996725fdafb0b412b413
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/scnet_mask_head.py
@@ -0,0 +1,27 @@
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
+from .fcn_mask_head import FCNMaskHead
+
+
+@HEADS.register_module()
+class SCNetMaskHead(FCNMaskHead):
+ """Mask head for `SCNet `_.
+
+ Args:
+ conv_to_res (bool, optional): if True, change the conv layers to
+ ``SimplifiedBasicBlock``.
+ """
+
+ def __init__(self, conv_to_res=True, **kwargs):
+ super(SCNetMaskHead, self).__init__(**kwargs)
+ self.conv_to_res = conv_to_res
+ if conv_to_res:
+ assert self.conv_kernel_size == 3
+ self.num_res_blocks = self.num_convs // 2
+ self.convs = ResLayer(
+ SimplifiedBasicBlock,
+ self.in_channels,
+ self.conv_out_channels,
+ self.num_res_blocks,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
diff --git a/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py b/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..df85a0112d27d97301fff56189f99bee0bf8efa5
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_heads/scnet_semantic_head.py
@@ -0,0 +1,27 @@
+from mmdet.models.builder import HEADS
+from mmdet.models.utils import ResLayer, SimplifiedBasicBlock
+from .fused_semantic_head import FusedSemanticHead
+
+
+@HEADS.register_module()
+class SCNetSemanticHead(FusedSemanticHead):
+ """Mask head for `SCNet `_.
+
+ Args:
+ conv_to_res (bool, optional): if True, change the conv layers to
+ ``SimplifiedBasicBlock``.
+ """
+
+ def __init__(self, conv_to_res=True, **kwargs):
+ super(SCNetSemanticHead, self).__init__(**kwargs)
+ self.conv_to_res = conv_to_res
+ if self.conv_to_res:
+ num_res_blocks = self.num_convs // 2
+ self.convs = ResLayer(
+ SimplifiedBasicBlock,
+ self.in_channels,
+ self.conv_out_channels,
+ num_res_blocks,
+ conv_cfg=self.conv_cfg,
+ norm_cfg=self.norm_cfg)
+ self.num_convs = num_res_blocks
diff --git a/mmdet/models/roi_heads/mask_scoring_roi_head.py b/mmdet/models/roi_heads/mask_scoring_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6e55c7752209cb5c15eab689ad9e8ac1fef1b66
--- /dev/null
+++ b/mmdet/models/roi_heads/mask_scoring_roi_head.py
@@ -0,0 +1,122 @@
+import torch
+
+from mmdet.core import bbox2roi
+from ..builder import HEADS, build_head
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class MaskScoringRoIHead(StandardRoIHead):
+ """Mask Scoring RoIHead for Mask Scoring RCNN.
+
+ https://arxiv.org/abs/1903.00241
+ """
+
+ def __init__(self, mask_iou_head, **kwargs):
+ assert mask_iou_head is not None
+ super(MaskScoringRoIHead, self).__init__(**kwargs)
+ self.mask_iou_head = build_head(mask_iou_head)
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ super(MaskScoringRoIHead, self).init_weights(pretrained)
+ self.mask_iou_head.init_weights()
+
+ def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+ img_metas):
+ """Run forward function and calculate loss for Mask head in
+ training."""
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ mask_results = super(MaskScoringRoIHead,
+ self)._mask_forward_train(x, sampling_results,
+ bbox_feats, gt_masks,
+ img_metas)
+ if mask_results['loss_mask'] is None:
+ return mask_results
+
+ # mask iou head forward and loss
+ pos_mask_pred = mask_results['mask_pred'][
+ range(mask_results['mask_pred'].size(0)), pos_labels]
+ mask_iou_pred = self.mask_iou_head(mask_results['mask_feats'],
+ pos_mask_pred)
+ pos_mask_iou_pred = mask_iou_pred[range(mask_iou_pred.size(0)),
+ pos_labels]
+
+ mask_iou_targets = self.mask_iou_head.get_targets(
+ sampling_results, gt_masks, pos_mask_pred,
+ mask_results['mask_targets'], self.train_cfg)
+ loss_mask_iou = self.mask_iou_head.loss(pos_mask_iou_pred,
+ mask_iou_targets)
+ mask_results['loss_mask'].update(loss_mask_iou)
+ return mask_results
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Obtain mask prediction without augmentation."""
+ # image shapes of images in the batch
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ num_imgs = len(det_bboxes)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ num_classes = self.mask_head.num_classes
+ segm_results = [[[] for _ in range(num_classes)]
+ for _ in range(num_imgs)]
+ mask_scores = [[[] for _ in range(num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i]
+ for i in range(num_imgs)
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ mask_results = self._mask_forward(x, mask_rois)
+ concat_det_labels = torch.cat(det_labels)
+ # get mask scores with mask iou head
+ mask_feats = mask_results['mask_feats']
+ mask_pred = mask_results['mask_pred']
+ mask_iou_pred = self.mask_iou_head(
+ mask_feats, mask_pred[range(concat_det_labels.size(0)),
+ concat_det_labels])
+ # split batch mask prediction back to each image
+ num_bboxes_per_img = tuple(len(_bbox) for _bbox in _bboxes)
+ mask_preds = mask_pred.split(num_bboxes_per_img, 0)
+ mask_iou_preds = mask_iou_pred.split(num_bboxes_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ mask_scores = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ mask_scores.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ segm_result = self.mask_head.get_seg_masks(
+ mask_preds[i], _bboxes[i], det_labels[i],
+ self.test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ # get mask scores with mask iou head
+ mask_score = self.mask_iou_head.get_mask_scores(
+ mask_iou_preds[i], det_bboxes[i], det_labels[i])
+ segm_results.append(segm_result)
+ mask_scores.append(mask_score)
+ return list(zip(segm_results, mask_scores))
diff --git a/mmdet/models/roi_heads/pisa_roi_head.py b/mmdet/models/roi_heads/pisa_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..e01113629837eb9c065ba40cd4025899b7bd0172
--- /dev/null
+++ b/mmdet/models/roi_heads/pisa_roi_head.py
@@ -0,0 +1,159 @@
+from mmdet.core import bbox2roi
+from ..builder import HEADS
+from ..losses.pisa_loss import carl_loss, isr_p
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class PISARoIHead(StandardRoIHead):
+ r"""The RoI head for `Prime Sample Attention in Object Detection
+ `_."""
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """Forward function for training.
+
+ Args:
+ x (list[Tensor]): List of multi-level img features.
+ img_metas (list[dict]): List of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+            proposal_list (list[Tensor]): List of region proposals.
+            gt_bboxes (list[Tensor]): Each item is the ground-truth boxes for
+                the corresponding image, in [tl_x, tl_y, br_x, br_y] format.
+            gt_labels (list[Tensor]): Class indices corresponding to each box.
+ gt_bboxes_ignore (list[Tensor], optional): Specify which bounding
+ boxes can be ignored when computing the loss.
+            gt_masks (None | Tensor): True segmentation masks for each box,
+                used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # assign gts and sample proposals
+ if self.with_bbox or self.with_mask:
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ neg_label_weights = []
+ for i in range(num_imgs):
+ assign_result = self.bbox_assigner.assign(
+ proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+ gt_labels[i])
+ sampling_result = self.bbox_sampler.sample(
+ assign_result,
+ proposal_list[i],
+ gt_bboxes[i],
+ gt_labels[i],
+ feats=[lvl_feat[i][None] for lvl_feat in x])
+ # neg label weight is obtained by sampling when using ISR-N
+ neg_label_weight = None
+ if isinstance(sampling_result, tuple):
+ sampling_result, neg_label_weight = sampling_result
+ sampling_results.append(sampling_result)
+ neg_label_weights.append(neg_label_weight)
+
+ losses = dict()
+ # bbox head forward and loss
+ if self.with_bbox:
+ bbox_results = self._bbox_forward_train(
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ neg_label_weights=neg_label_weights)
+ losses.update(bbox_results['loss_bbox'])
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(x, sampling_results,
+ bbox_results['bbox_feats'],
+ gt_masks, img_metas)
+ losses.update(mask_results['loss_mask'])
+
+ return losses
+
+ def _bbox_forward(self, x, rois):
+ """Box forward function used in both training and testing."""
+ # TODO: a more flexible way to decide which feature maps to use
+ bbox_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs], rois)
+ if self.with_shared_head:
+ bbox_feats = self.shared_head(bbox_feats)
+ cls_score, bbox_pred = self.bbox_head(bbox_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+ return bbox_results
+
+ def _bbox_forward_train(self,
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ img_metas,
+ neg_label_weights=None):
+ """Run forward function and calculate loss for box head in training."""
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+
+ bbox_results = self._bbox_forward(x, rois)
+
+ bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, self.train_cfg)
+
+        # neg_label_weights obtained by the sampler are image-wise; map them
+        # back to the corresponding locations in the label weights
+ if neg_label_weights[0] is not None:
+ label_weights = bbox_targets[1]
+ cur_num_rois = 0
+ for i in range(len(sampling_results)):
+ num_pos = sampling_results[i].pos_inds.size(0)
+ num_neg = sampling_results[i].neg_inds.size(0)
+ label_weights[cur_num_rois + num_pos:cur_num_rois + num_pos +
+ num_neg] = neg_label_weights[i]
+ cur_num_rois += num_pos + num_neg
+
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+
+ # Apply ISR-P
+ isr_cfg = self.train_cfg.get('isr', None)
+ if isr_cfg is not None:
+ bbox_targets = isr_p(
+ cls_score,
+ bbox_pred,
+ bbox_targets,
+ rois,
+ sampling_results,
+ self.bbox_head.loss_cls,
+ self.bbox_head.bbox_coder,
+ **isr_cfg,
+ num_class=self.bbox_head.num_classes)
+ loss_bbox = self.bbox_head.loss(cls_score, bbox_pred, rois,
+ *bbox_targets)
+
+ # Add CARL Loss
+ carl_cfg = self.train_cfg.get('carl', None)
+ if carl_cfg is not None:
+ loss_carl = carl_loss(
+ cls_score,
+ bbox_targets[0],
+ bbox_pred,
+ bbox_targets[2],
+ self.bbox_head.loss_bbox,
+ **carl_cfg,
+ num_class=self.bbox_head.num_classes)
+ loss_bbox.update(loss_carl)
+
+ bbox_results.update(loss_bbox=loss_bbox)
+ return bbox_results
diff --git a/mmdet/models/roi_heads/point_rend_roi_head.py b/mmdet/models/roi_heads/point_rend_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..478cdf5bff6779e9291f94c543205289036ea2c6
--- /dev/null
+++ b/mmdet/models/roi_heads/point_rend_roi_head.py
@@ -0,0 +1,218 @@
+# Modified from https://github.com/facebookresearch/detectron2/tree/master/projects/PointRend # noqa
+
+import torch
+import torch.nn.functional as F
+from mmcv.ops import point_sample, rel_roi_point_to_rel_img_point
+
+from mmdet.core import bbox2roi, bbox_mapping, merge_aug_masks
+from .. import builder
+from ..builder import HEADS
+from .standard_roi_head import StandardRoIHead
+
+
+@HEADS.register_module()
+class PointRendRoIHead(StandardRoIHead):
+ """`PointRend `_."""
+
+ def __init__(self, point_head, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ assert self.with_bbox and self.with_mask
+ self.init_point_head(point_head)
+
+ def init_point_head(self, point_head):
+ """Initialize ``point_head``"""
+ self.point_head = builder.build_head(point_head)
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ """
+ super().init_weights(pretrained)
+ self.point_head.init_weights()
+
+ def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+ img_metas):
+ """Run forward function and calculate loss for mask head and point head
+ in training."""
+ mask_results = super()._mask_forward_train(x, sampling_results,
+ bbox_feats, gt_masks,
+ img_metas)
+ if mask_results['loss_mask'] is not None:
+ loss_point = self._mask_point_forward_train(
+ x, sampling_results, mask_results['mask_pred'], gt_masks,
+ img_metas)
+ mask_results['loss_mask'].update(loss_point)
+
+ return mask_results
+
+ def _mask_point_forward_train(self, x, sampling_results, mask_pred,
+ gt_masks, img_metas):
+ """Run forward function and calculate loss for point head in
+ training."""
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ rel_roi_points = self.point_head.get_roi_rel_points_train(
+ mask_pred, pos_labels, cfg=self.train_cfg)
+ rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, rois, rel_roi_points, img_metas)
+ coarse_point_feats = point_sample(mask_pred, rel_roi_points)
+ mask_point_pred = self.point_head(fine_grained_point_feats,
+ coarse_point_feats)
+ mask_point_target = self.point_head.get_targets(
+ rois, rel_roi_points, sampling_results, gt_masks, self.train_cfg)
+ loss_mask_point = self.point_head.loss(mask_point_pred,
+ mask_point_target, pos_labels)
+
+ return loss_mask_point
+
+ def _get_fine_grained_point_feats(self, x, rois, rel_roi_points,
+ img_metas):
+ """Sample fine grained feats from each level feature map and
+ concatenate them together."""
+ num_imgs = len(img_metas)
+ fine_grained_feats = []
+ for idx in range(self.mask_roi_extractor.num_inputs):
+ feats = x[idx]
+ spatial_scale = 1. / float(
+ self.mask_roi_extractor.featmap_strides[idx])
+ point_feats = []
+ for batch_ind in range(num_imgs):
+ # unravel batch dim
+ feat = feats[batch_ind].unsqueeze(0)
+ inds = (rois[:, 0].long() == batch_ind)
+ if inds.any():
+ rel_img_points = rel_roi_point_to_rel_img_point(
+ rois[inds], rel_roi_points[inds], feat.shape[2:],
+ spatial_scale).unsqueeze(0)
+ point_feat = point_sample(feat, rel_img_points)
+ point_feat = point_feat.squeeze(0).transpose(0, 1)
+ point_feats.append(point_feat)
+ fine_grained_feats.append(torch.cat(point_feats, dim=0))
+ return torch.cat(fine_grained_feats, dim=1)
+
+ def _mask_point_forward_test(self, x, rois, label_pred, mask_pred,
+ img_metas):
+ """Mask refining process with point head in testing."""
+ refined_mask_pred = mask_pred.clone()
+ for subdivision_step in range(self.test_cfg.subdivision_steps):
+ refined_mask_pred = F.interpolate(
+ refined_mask_pred,
+ scale_factor=self.test_cfg.scale_factor,
+ mode='bilinear',
+ align_corners=False)
+            # If `subdivision_num_points` is larger than or equal to the
+            # resolution of the next step, then we can skip this step
+ num_rois, channels, mask_height, mask_width = \
+ refined_mask_pred.shape
+ if (self.test_cfg.subdivision_num_points >=
+ self.test_cfg.scale_factor**2 * mask_height * mask_width
+ and
+ subdivision_step < self.test_cfg.subdivision_steps - 1):
+ continue
+ point_indices, rel_roi_points = \
+ self.point_head.get_roi_rel_points_test(
+ refined_mask_pred, label_pred, cfg=self.test_cfg)
+ fine_grained_point_feats = self._get_fine_grained_point_feats(
+ x, rois, rel_roi_points, img_metas)
+ coarse_point_feats = point_sample(mask_pred, rel_roi_points)
+ mask_point_pred = self.point_head(fine_grained_point_feats,
+ coarse_point_feats)
+
+ point_indices = point_indices.unsqueeze(1).expand(-1, channels, -1)
+ refined_mask_pred = refined_mask_pred.reshape(
+ num_rois, channels, mask_height * mask_width)
+ refined_mask_pred = refined_mask_pred.scatter_(
+ 2, point_indices, mask_point_pred)
+ refined_mask_pred = refined_mask_pred.view(num_rois, channels,
+ mask_height, mask_width)
+
+ return refined_mask_pred
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Obtain mask prediction without augmentation."""
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+ num_imgs = len(det_bboxes)
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ segm_results = [[[] for _ in range(self.mask_head.num_classes)]
+ for _ in range(num_imgs)]
+ else:
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i][:, :4]
+ for i in range(len(det_bboxes))
+ ]
+ mask_rois = bbox2roi(_bboxes)
+ mask_results = self._mask_forward(x, mask_rois)
+ # split batch mask prediction back to each image
+ mask_pred = mask_results['mask_pred']
+ num_mask_roi_per_img = [len(det_bbox) for det_bbox in det_bboxes]
+ mask_preds = mask_pred.split(num_mask_roi_per_img, 0)
+ mask_rois = mask_rois.split(num_mask_roi_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ x_i = [xx[[i]] for xx in x]
+ mask_rois_i = mask_rois[i]
+ mask_rois_i[:, 0] = 0 # TODO: remove this hack
+ mask_pred_i = self._mask_point_forward_test(
+ x_i, mask_rois_i, det_labels[i], mask_preds[i],
+ [img_metas])
+ segm_result = self.mask_head.get_seg_masks(
+ mask_pred_i, _bboxes[i], det_labels[i], self.test_cfg,
+ ori_shapes[i], scale_factors[i], rescale)
+ segm_results.append(segm_result)
+ return segm_results
+
+ def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
+ """Test for mask head with test time augmentation."""
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[] for _ in range(self.mask_head.num_classes)]
+ else:
+ aug_masks = []
+ for x, img_meta in zip(feats, img_metas):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip)
+ mask_rois = bbox2roi([_bboxes])
+ mask_results = self._mask_forward(x, mask_rois)
+ mask_results['mask_pred'] = self._mask_point_forward_test(
+ x, mask_rois, det_labels, mask_results['mask_pred'],
+ img_metas)
+ # convert to numpy array to save memory
+ aug_masks.append(
+ mask_results['mask_pred'].sigmoid().cpu().numpy())
+ merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ segm_result = self.mask_head.get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ self.test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return segm_result
diff --git a/mmdet/models/roi_heads/roi_extractors/__init__.py b/mmdet/models/roi_heads/roi_extractors/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6ec0ecc3063cd23c2463f2f53f1c2a83b04d43b
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/__init__.py
@@ -0,0 +1,7 @@
+from .generic_roi_extractor import GenericRoIExtractor
+from .single_level_roi_extractor import SingleRoIExtractor
+
+__all__ = [
+ 'SingleRoIExtractor',
+ 'GenericRoIExtractor',
+]
diff --git a/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..847932547c6c309ae38b45dc43ac0ef8ca66d347
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/base_roi_extractor.py
@@ -0,0 +1,83 @@
+from abc import ABCMeta, abstractmethod
+
+import torch
+import torch.nn as nn
+from mmcv import ops
+
+
+class BaseRoIExtractor(nn.Module, metaclass=ABCMeta):
+ """Base class for RoI extractor.
+
+ Args:
+ roi_layer (dict): Specify RoI layer type and arguments.
+ out_channels (int): Output channels of RoI layers.
+ featmap_strides (List[int]): Strides of input feature maps.
+ """
+
+ def __init__(self, roi_layer, out_channels, featmap_strides):
+ super(BaseRoIExtractor, self).__init__()
+ self.roi_layers = self.build_roi_layers(roi_layer, featmap_strides)
+ self.out_channels = out_channels
+ self.featmap_strides = featmap_strides
+ self.fp16_enabled = False
+
+ @property
+ def num_inputs(self):
+ """int: Number of input feature maps."""
+ return len(self.featmap_strides)
+
+ def init_weights(self):
+ pass
+
+ def build_roi_layers(self, layer_cfg, featmap_strides):
+ """Build RoI operator to extract feature from each level feature map.
+
+ Args:
+ layer_cfg (dict): Dictionary to construct and config RoI layer
+ operation. Options are modules under ``mmcv/ops`` such as
+ ``RoIAlign``.
+ featmap_strides (List[int]): The stride of input feature map w.r.t
+ to the original image size, which would be used to scale RoI
+ coordinate (original image coordinate system) to feature
+ coordinate system.
+
+ Returns:
+ nn.ModuleList: The RoI extractor modules for each level feature
+ map.
+ """
+
+ cfg = layer_cfg.copy()
+ layer_type = cfg.pop('type')
+ assert hasattr(ops, layer_type)
+ layer_cls = getattr(ops, layer_type)
+ roi_layers = nn.ModuleList(
+ [layer_cls(spatial_scale=1 / s, **cfg) for s in featmap_strides])
+ return roi_layers
+
+ def roi_rescale(self, rois, scale_factor):
+ """Scale RoI coordinates by scale factor.
+
+ Args:
+ rois (torch.Tensor): RoI (Region of Interest), shape (n, 5)
+ scale_factor (float): Scale factor that RoI will be multiplied by.
+
+ Returns:
+ torch.Tensor: Scaled RoI.
+ """
+
+ cx = (rois[:, 1] + rois[:, 3]) * 0.5
+ cy = (rois[:, 2] + rois[:, 4]) * 0.5
+ w = rois[:, 3] - rois[:, 1]
+ h = rois[:, 4] - rois[:, 2]
+ new_w = w * scale_factor
+ new_h = h * scale_factor
+ x1 = cx - new_w * 0.5
+ x2 = cx + new_w * 0.5
+ y1 = cy - new_h * 0.5
+ y2 = cy + new_h * 0.5
+ new_rois = torch.stack((rois[:, 0], x1, y1, x2, y2), dim=-1)
+ return new_rois
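+
+    # Worked example (illustrative): an RoI (batch_ind, x1, y1, x2, y2) of
+    # (0, 10, 10, 30, 50) has center (20, 30) and size 20 x 40; with
+    # scale_factor = 2 it becomes (0, 0, -10, 40, 70). Note that this function
+    # does not clip the result to the image boundary.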
+
+ @abstractmethod
+ def forward(self, feats, rois, roi_scale_factor=None):
+ pass
diff --git a/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..80c25bb8fde7844c994bfc1f4ae1a2d960cbf3d6
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/generic_roi_extractor.py
@@ -0,0 +1,83 @@
+from mmcv.cnn.bricks import build_plugin_layer
+from mmcv.runner import force_fp32
+
+from mmdet.models.builder import ROI_EXTRACTORS
+from .base_roi_extractor import BaseRoIExtractor
+
+
+@ROI_EXTRACTORS.register_module()
+class GenericRoIExtractor(BaseRoIExtractor):
+    """Extract RoI features from all feature map levels.
+
+ This is the implementation of `A novel Region of Interest Extraction Layer
+ for Instance Segmentation `_.
+
+ Args:
+ aggregation (str): The method to aggregate multiple feature maps.
+ Options are 'sum', 'concat'. Default: 'sum'.
+ pre_cfg (dict | None): Specify pre-processing modules. Default: None.
+ post_cfg (dict | None): Specify post-processing modules. Default: None.
+ kwargs (keyword arguments): Arguments that are the same
+ as :class:`BaseRoIExtractor`.
+ """
+
+ def __init__(self,
+ aggregation='sum',
+ pre_cfg=None,
+ post_cfg=None,
+ **kwargs):
+ super(GenericRoIExtractor, self).__init__(**kwargs)
+
+ assert aggregation in ['sum', 'concat']
+
+ self.aggregation = aggregation
+ self.with_post = post_cfg is not None
+ self.with_pre = pre_cfg is not None
+ # build pre/post processing modules
+ if self.with_post:
+ self.post_module = build_plugin_layer(post_cfg, '_post_module')[1]
+ if self.with_pre:
+ self.pre_module = build_plugin_layer(pre_cfg, '_pre_module')[1]
+
+ @force_fp32(apply_to=('feats', ), out_fp16=True)
+ def forward(self, feats, rois, roi_scale_factor=None):
+ """Forward function."""
+ if len(feats) == 1:
+ return self.roi_layers[0](feats[0], rois)
+
+ out_size = self.roi_layers[0].output_size
+ num_levels = len(feats)
+ roi_feats = feats[0].new_zeros(
+ rois.size(0), self.out_channels, *out_size)
+
+        # sometimes rois is an empty tensor
+ if roi_feats.shape[0] == 0:
+ return roi_feats
+
+ if roi_scale_factor is not None:
+ rois = self.roi_rescale(rois, roi_scale_factor)
+
+ # mark the starting channels for concat mode
+ start_channels = 0
+ for i in range(num_levels):
+ roi_feats_t = self.roi_layers[i](feats[i], rois)
+ end_channels = start_channels + roi_feats_t.size(1)
+ if self.with_pre:
+ # apply pre-processing to a RoI extracted from each layer
+ roi_feats_t = self.pre_module(roi_feats_t)
+ if self.aggregation == 'sum':
+ # and sum them all
+ roi_feats += roi_feats_t
+ else:
+ # and concat them along channel dimension
+ roi_feats[:, start_channels:end_channels] = roi_feats_t
+ # update channels starting position
+ start_channels = end_channels
+ # check if concat channels match at the end
+ if self.aggregation == 'concat':
+ assert start_channels == self.out_channels
+
+ if self.with_post:
+            # apply post-processing before returning the result
+ roi_feats = self.post_module(roi_feats)
+ return roi_feats
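+
+    # Illustrative config sketch (an assumption for demonstration, not copied
+    # from this repo's configs): pool every FPN level with RoIAlign and sum
+    # the per-level features:
+    #   roi_extractor = dict(
+    #       type='GenericRoIExtractor',
+    #       aggregation='sum',
+    #       roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+    #       out_channels=256,
+    #       featmap_strides=[4, 8, 16, 32])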
diff --git a/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cfc838f23270a1ae4d70f90059b67a890850e981
--- /dev/null
+++ b/mmdet/models/roi_heads/roi_extractors/single_level_roi_extractor.py
@@ -0,0 +1,108 @@
+import torch
+from mmcv.runner import force_fp32
+
+from mmdet.models.builder import ROI_EXTRACTORS
+from .base_roi_extractor import BaseRoIExtractor
+
+
+@ROI_EXTRACTORS.register_module()
+class SingleRoIExtractor(BaseRoIExtractor):
+ """Extract RoI features from a single level feature map.
+
+ If there are multiple input feature levels, each RoI is mapped to a level
+ according to its scale. The mapping rule is proposed in
+ `FPN `_.
+
+ Args:
+ roi_layer (dict): Specify RoI layer type and arguments.
+ out_channels (int): Output channels of RoI layers.
+ featmap_strides (List[int]): Strides of input feature maps.
+ finest_scale (int): Scale threshold of mapping to level 0. Default: 56.
+ """
+
+ def __init__(self,
+ roi_layer,
+ out_channels,
+ featmap_strides,
+ finest_scale=56):
+ super(SingleRoIExtractor, self).__init__(roi_layer, out_channels,
+ featmap_strides)
+ self.finest_scale = finest_scale
+
+ def map_roi_levels(self, rois, num_levels):
+ """Map rois to corresponding feature levels by scales.
+
+ - scale < finest_scale * 2: level 0
+ - finest_scale * 2 <= scale < finest_scale * 4: level 1
+ - finest_scale * 4 <= scale < finest_scale * 8: level 2
+ - scale >= finest_scale * 8: level 3
+
+ Args:
+ rois (Tensor): Input RoIs, shape (k, 5).
+ num_levels (int): Total level number.
+
+ Returns:
+ Tensor: Level index (0-based) of each RoI, shape (k, )
+ """
+ scale = torch.sqrt(
+ (rois[:, 3] - rois[:, 1]) * (rois[:, 4] - rois[:, 2]))
+ target_lvls = torch.floor(torch.log2(scale / self.finest_scale + 1e-6))
+ target_lvls = target_lvls.clamp(min=0, max=num_levels - 1).long()
+ return target_lvls
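+
+    # Worked example (illustrative, with the default finest_scale = 56): an
+    # RoI of size 100 x 100 has scale = 100 < 112, so it maps to level 0; a
+    # 150 x 150 RoI gives log2(150 / 56), about 1.42, floored to level 1; a
+    # 1000 x 1000 RoI gives log2(1000 / 56), about 4.2, which is clamped to
+    # the last level (num_levels - 1).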
+
+ @force_fp32(apply_to=('feats', ), out_fp16=True)
+ def forward(self, feats, rois, roi_scale_factor=None):
+ """Forward function."""
+ out_size = self.roi_layers[0].output_size
+ num_levels = len(feats)
+ expand_dims = (-1, self.out_channels * out_size[0] * out_size[1])
+ if torch.onnx.is_in_onnx_export():
+            # Workaround to export Mask R-CNN to ONNX
+ roi_feats = rois[:, :1].clone().detach()
+ roi_feats = roi_feats.expand(*expand_dims)
+ roi_feats = roi_feats.reshape(-1, self.out_channels, *out_size)
+ roi_feats = roi_feats * 0
+ else:
+ roi_feats = feats[0].new_zeros(
+ rois.size(0), self.out_channels, *out_size)
+            # TODO: remove this when parrots supports it
+ if torch.__version__ == 'parrots':
+ roi_feats.requires_grad = True
+
+ if num_levels == 1:
+ if len(rois) == 0:
+ return roi_feats
+ return self.roi_layers[0](feats[0], rois)
+
+ target_lvls = self.map_roi_levels(rois, num_levels)
+
+ if roi_scale_factor is not None:
+ rois = self.roi_rescale(rois, roi_scale_factor)
+
+ for i in range(num_levels):
+ mask = target_lvls == i
+ if torch.onnx.is_in_onnx_export():
+ # To keep all roi_align nodes exported to onnx
+ # and skip nonzero op
+ mask = mask.float().unsqueeze(-1).expand(*expand_dims).reshape(
+ roi_feats.shape)
+ roi_feats_t = self.roi_layers[i](feats[i], rois)
+ roi_feats_t *= mask
+ roi_feats += roi_feats_t
+ continue
+ inds = mask.nonzero(as_tuple=False).squeeze(1)
+ if inds.numel() > 0:
+ rois_ = rois[inds]
+ roi_feats_t = self.roi_layers[i](feats[i], rois_)
+ roi_feats[inds] = roi_feats_t
+ else:
+ # Sometimes some pyramid levels will not be used for RoI
+ # feature extraction and this will cause an incomplete
+ # computation graph in one GPU, which is different from those
+ # in other GPUs and will cause a hanging error.
+ # Therefore, we add it to ensure each feature pyramid is
+ # included in the computation graph to avoid runtime bugs.
+ roi_feats += sum(
+ x.view(-1)[0]
+ for x in self.parameters()) * 0. + feats[i].sum() * 0.
+ return roi_feats
diff --git a/mmdet/models/roi_heads/scnet_roi_head.py b/mmdet/models/roi_heads/scnet_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..85aaa2f0600afbdfc8b0917cb5f341740776a603
--- /dev/null
+++ b/mmdet/models/roi_heads/scnet_roi_head.py
@@ -0,0 +1,582 @@
+import torch
+import torch.nn.functional as F
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+ merge_aug_masks, multiclass_nms)
+from ..builder import HEADS, build_head, build_roi_extractor
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module()
+class SCNetRoIHead(CascadeRoIHead):
+ """RoIHead for `SCNet `_.
+
+ Args:
+ num_stages (int): number of cascade stages.
+ stage_loss_weights (list): loss weight of cascade stages.
+ semantic_roi_extractor (dict): config to init semantic roi extractor.
+ semantic_head (dict): config to init semantic head.
+ feat_relay_head (dict): config to init feature_relay_head.
+ glbctx_head (dict): config to init global context head.
+ """
+
+ def __init__(self,
+ num_stages,
+ stage_loss_weights,
+ semantic_roi_extractor=None,
+ semantic_head=None,
+ feat_relay_head=None,
+ glbctx_head=None,
+ **kwargs):
+ super(SCNetRoIHead, self).__init__(num_stages, stage_loss_weights,
+ **kwargs)
+ assert self.with_bbox and self.with_mask
+ assert not self.with_shared_head # shared head is not supported
+
+ if semantic_head is not None:
+ self.semantic_roi_extractor = build_roi_extractor(
+ semantic_roi_extractor)
+ self.semantic_head = build_head(semantic_head)
+
+ if feat_relay_head is not None:
+ self.feat_relay_head = build_head(feat_relay_head)
+
+ if glbctx_head is not None:
+ self.glbctx_head = build_head(glbctx_head)
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize ``mask_head``"""
+ if mask_roi_extractor is not None:
+ self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
+ self.mask_head = build_head(mask_head)
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ for i in range(self.num_stages):
+ if self.with_bbox:
+ self.bbox_roi_extractor[i].init_weights()
+ self.bbox_head[i].init_weights()
+ if self.with_mask:
+ self.mask_roi_extractor.init_weights()
+ self.mask_head.init_weights()
+ if self.with_semantic:
+ self.semantic_head.init_weights()
+ if self.with_glbctx:
+ self.glbctx_head.init_weights()
+ if self.with_feat_relay:
+ self.feat_relay_head.init_weights()
+
+ @property
+ def with_semantic(self):
+ """bool: whether the head has semantic head"""
+ return hasattr(self,
+ 'semantic_head') and self.semantic_head is not None
+
+ @property
+ def with_feat_relay(self):
+ """bool: whether the head has feature relay head"""
+ return (hasattr(self, 'feat_relay_head')
+ and self.feat_relay_head is not None)
+
+ @property
+ def with_glbctx(self):
+ """bool: whether the head has global context head"""
+ return hasattr(self, 'glbctx_head') and self.glbctx_head is not None
+
+ def _fuse_glbctx(self, roi_feats, glbctx_feat, rois):
+ """Fuse global context feats with roi feats."""
+ assert roi_feats.size(0) == rois.size(0)
+ img_inds = torch.unique(rois[:, 0].cpu(), sorted=True).long()
+ fused_feats = torch.zeros_like(roi_feats)
+ for img_id in img_inds:
+ inds = (rois[:, 0] == img_id.item())
+ fused_feats[inds] = roi_feats[inds] + glbctx_feat[img_id]
+ return fused_feats
+
+ def _slice_pos_feats(self, feats, sampling_results):
+ """Get features from pos rois."""
+ num_rois = [res.bboxes.size(0) for res in sampling_results]
+ num_pos_rois = [res.pos_bboxes.size(0) for res in sampling_results]
+ inds = torch.zeros(sum(num_rois), dtype=torch.bool)
+ start = 0
+ for i in range(len(num_rois)):
+ start = 0 if i == 0 else start + num_rois[i - 1]
+ stop = start + num_pos_rois[i]
+ inds[start:stop] = 1
+ sliced_feats = feats[inds]
+ return sliced_feats
+
+ def _bbox_forward(self,
+ stage,
+ x,
+ rois,
+ semantic_feat=None,
+ glbctx_feat=None):
+ """Box head forward function used in both training and testing."""
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(
+ x[:len(bbox_roi_extractor.featmap_strides)], rois)
+ if self.with_semantic and semantic_feat is not None:
+ bbox_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ rois)
+ if bbox_semantic_feat.shape[-2:] != bbox_feats.shape[-2:]:
+ bbox_semantic_feat = F.adaptive_avg_pool2d(
+ bbox_semantic_feat, bbox_feats.shape[-2:])
+ bbox_feats += bbox_semantic_feat
+ if self.with_glbctx and glbctx_feat is not None:
+ bbox_feats = self._fuse_glbctx(bbox_feats, glbctx_feat, rois)
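+        # `return_shared_feat=True` makes the box head also return its shared
+        # features, which are later passed through the feature relay head and
+        # added to the mask features.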
+ cls_score, bbox_pred, relayed_feat = bbox_head(
+ bbox_feats, return_shared_feat=True)
+
+ bbox_results = dict(
+ cls_score=cls_score,
+ bbox_pred=bbox_pred,
+ relayed_feat=relayed_feat)
+ return bbox_results
+
+ def _mask_forward(self,
+ x,
+ rois,
+ semantic_feat=None,
+ glbctx_feat=None,
+ relayed_feat=None):
+ """Mask head forward function used in both training and testing."""
+ mask_feats = self.mask_roi_extractor(
+ x[:self.mask_roi_extractor.num_inputs], rois)
+ if self.with_semantic and semantic_feat is not None:
+ mask_semantic_feat = self.semantic_roi_extractor([semantic_feat],
+ rois)
+ if mask_semantic_feat.shape[-2:] != mask_feats.shape[-2:]:
+ mask_semantic_feat = F.adaptive_avg_pool2d(
+ mask_semantic_feat, mask_feats.shape[-2:])
+ mask_feats += mask_semantic_feat
+ if self.with_glbctx and glbctx_feat is not None:
+ mask_feats = self._fuse_glbctx(mask_feats, glbctx_feat, rois)
+ if self.with_feat_relay and relayed_feat is not None:
+ mask_feats = mask_feats + relayed_feat
+ mask_pred = self.mask_head(mask_feats)
+ mask_results = dict(mask_pred=mask_pred)
+
+ return mask_results
+
+ def _bbox_forward_train(self,
+ stage,
+ x,
+ sampling_results,
+ gt_bboxes,
+ gt_labels,
+ rcnn_train_cfg,
+ semantic_feat=None,
+ glbctx_feat=None):
+ """Run forward function and calculate loss for box head in training."""
+ bbox_head = self.bbox_head[stage]
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(
+ stage,
+ x,
+ rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+
+ bbox_targets = bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, rcnn_train_cfg)
+ loss_bbox = bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(
+ loss_bbox=loss_bbox, rois=rois, bbox_targets=bbox_targets)
+ return bbox_results
+
+ def _mask_forward_train(self,
+ x,
+ sampling_results,
+ gt_masks,
+ rcnn_train_cfg,
+ semantic_feat=None,
+ glbctx_feat=None,
+ relayed_feat=None):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_results = self._mask_forward(
+ x,
+ pos_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat,
+ relayed_feat=relayed_feat)
+
+ mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
+ rcnn_train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = self.mask_head.loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+
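+        # `loss_mask` is already a dict (it contains the 'loss_mask' term),
+        # which forward_train consumes directly.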
+ mask_results = loss_mask
+ return mask_results
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None,
+ gt_semantic_seg=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+
+            proposal_list (list[Tensor]): list of region proposals.
+
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+
+ gt_labels (list[Tensor]): class indices corresponding to each box
+
+            gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+                boxes can be ignored when computing the loss.
+
+            gt_masks (None | Tensor): true segmentation masks for each box
+                used if the architecture supports a segmentation task.
+
+            gt_semantic_seg (None | list[Tensor]): semantic segmentation masks
+                used if the architecture supports a semantic segmentation
+                task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ losses = dict()
+
+ # semantic segmentation branch
+ if self.with_semantic:
+ semantic_pred, semantic_feat = self.semantic_head(x)
+ loss_seg = self.semantic_head.loss(semantic_pred, gt_semantic_seg)
+ losses['loss_semantic_seg'] = loss_seg
+ else:
+ semantic_feat = None
+
+ # global context branch
+ if self.with_glbctx:
+ mc_pred, glbctx_feat = self.glbctx_head(x)
+ loss_glbctx = self.glbctx_head.loss(mc_pred, gt_labels)
+ losses['loss_glbctx'] = loss_glbctx
+ else:
+ glbctx_feat = None
+
+ for i in range(self.num_stages):
+ self.current_stage = i
+ rcnn_train_cfg = self.train_cfg[i]
+ lw = self.stage_loss_weights[i]
+
+ # assign gts and sample proposals
+ sampling_results = []
+ bbox_assigner = self.bbox_assigner[i]
+ bbox_sampler = self.bbox_sampler[i]
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+
+ for j in range(num_imgs):
+ assign_result = bbox_assigner.assign(proposal_list[j],
+ gt_bboxes[j],
+ gt_bboxes_ignore[j],
+ gt_labels[j])
+ sampling_result = bbox_sampler.sample(
+ assign_result,
+ proposal_list[j],
+ gt_bboxes[j],
+ gt_labels[j],
+ feats=[lvl_feat[j][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ bbox_results = \
+ self._bbox_forward_train(
+ i, x, sampling_results, gt_bboxes, gt_labels,
+ rcnn_train_cfg, semantic_feat, glbctx_feat)
+ roi_labels = bbox_results['bbox_targets'][0]
+
+ for name, value in bbox_results['loss_bbox'].items():
+ losses[f's{i}.{name}'] = (
+ value * lw if 'loss' in name else value)
+
+ # refine boxes
+ if i < self.num_stages - 1:
+ pos_is_gts = [res.pos_is_gt for res in sampling_results]
+ with torch.no_grad():
+ proposal_list = self.bbox_head[i].refine_bboxes(
+ bbox_results['rois'], roi_labels,
+ bbox_results['bbox_pred'], pos_is_gts, img_metas)
+
+ if self.with_feat_relay:
+ relayed_feat = self._slice_pos_feats(bbox_results['relayed_feat'],
+ sampling_results)
+ relayed_feat = self.feat_relay_head(relayed_feat)
+ else:
+ relayed_feat = None
+
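+        # Unlike the box heads, the mask head runs only once, so its loss is
+        # weighted by the sum of all stage loss weights.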
+ mask_results = self._mask_forward_train(x, sampling_results, gt_masks,
+ rcnn_train_cfg, semantic_feat,
+ glbctx_feat, relayed_feat)
+ mask_lw = sum(self.stage_loss_weights)
+ losses['loss_mask'] = mask_lw * mask_results['loss_mask']
+
+ return losses
+
+ def simple_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test without augmentation."""
+ if self.with_semantic:
+ _, semantic_feat = self.semantic_head(x)
+ else:
+ semantic_feat = None
+
+ if self.with_glbctx:
+ mc_pred, glbctx_feat = self.glbctx_head(x)
+ else:
+ glbctx_feat = None
+
+ num_imgs = len(proposal_list)
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+ rcnn_test_cfg = self.test_cfg
+
+ rois = bbox2roi(proposal_list)
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i,
+ x,
+ rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ # split batch bbox prediction back to each image
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+ num_proposals_per_img = tuple(len(p) for p in proposal_list)
+ rois = rois.split(num_proposals_per_img, 0)
+ cls_score = cls_score.split(num_proposals_per_img, 0)
+ bbox_pred = bbox_pred.split(num_proposals_per_img, 0)
+ ms_scores.append(cls_score)
+
+ if i < self.num_stages - 1:
+ bbox_label = [s[:, :-1].argmax(dim=1) for s in cls_score]
+ rois = torch.cat([
+ bbox_head.regress_by_class(rois[i], bbox_label[i],
+ bbox_pred[i], img_metas[i])
+ for i in range(num_imgs)
+ ])
+
+ # average scores of each image by stages
+ cls_score = [
+ sum([score[i] for score in ms_scores]) / float(len(ms_scores))
+ for i in range(num_imgs)
+ ]
+
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(num_imgs):
+ det_bbox, det_label = self.bbox_head[-1].get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_pred[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+ det_bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head[-1].num_classes)
+ for i in range(num_imgs)
+ ]
+
+ if self.with_mask:
+ if all(det_bbox.shape[0] == 0 for det_bbox in det_bboxes):
+ mask_classes = self.mask_head.num_classes
+ det_segm_results = [[[] for _ in range(mask_classes)]
+ for _ in range(num_imgs)]
+ else:
+ if rescale and not isinstance(scale_factors[0], float):
+ scale_factors = [
+ torch.from_numpy(scale_factor).to(det_bboxes[0].device)
+ for scale_factor in scale_factors
+ ]
+ _bboxes = [
+ det_bboxes[i][:, :4] *
+ scale_factors[i] if rescale else det_bboxes[i]
+ for i in range(num_imgs)
+ ]
+ mask_rois = bbox2roi(_bboxes)
+
+ # get relay feature on mask_rois
+ bbox_results = self._bbox_forward(
+ -1,
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ relayed_feat = bbox_results['relayed_feat']
+ relayed_feat = self.feat_relay_head(relayed_feat)
+
+ mask_results = self._mask_forward(
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat,
+ relayed_feat=relayed_feat)
+ mask_pred = mask_results['mask_pred']
+
+ # split batch mask prediction back to each image
+ num_bbox_per_img = tuple(len(_bbox) for _bbox in _bboxes)
+ mask_preds = mask_pred.split(num_bbox_per_img, 0)
+
+ # apply mask post-processing to each image individually
+ det_segm_results = []
+ for i in range(num_imgs):
+ if det_bboxes[i].shape[0] == 0:
+ det_segm_results.append(
+ [[] for _ in range(self.mask_head.num_classes)])
+ else:
+ segm_result = self.mask_head.get_seg_masks(
+ mask_preds[i], _bboxes[i], det_labels[i],
+ self.test_cfg, ori_shapes[i], scale_factors[i],
+ rescale)
+ det_segm_results.append(segm_result)
+
+ # return results
+ if self.with_mask:
+ return list(zip(det_bbox_results, det_segm_results))
+ else:
+ return det_bbox_results
+
+ def aug_test(self, img_feats, proposal_list, img_metas, rescale=False):
+ if self.with_semantic:
+ semantic_feats = [
+ self.semantic_head(feat)[1] for feat in img_feats
+ ]
+ else:
+ semantic_feats = [None] * len(img_metas)
+
+ if self.with_glbctx:
+ glbctx_feats = [self.glbctx_head(feat)[1] for feat in img_feats]
+ else:
+ glbctx_feats = [None] * len(img_metas)
+
+ rcnn_test_cfg = self.test_cfg
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta, semantic_feat, glbctx_feat in zip(
+ img_feats, img_metas, semantic_feats, glbctx_feats):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip)
+ # "ms" in variable names means multi-stage
+ ms_scores = []
+
+ rois = bbox2roi([proposals])
+ for i in range(self.num_stages):
+ bbox_head = self.bbox_head[i]
+ bbox_results = self._bbox_forward(
+ i,
+ x,
+ rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ ms_scores.append(bbox_results['cls_score'])
+ if i < self.num_stages - 1:
+ bbox_label = bbox_results['cls_score'].argmax(dim=1)
+ rois = bbox_head.regress_by_class(
+ rois, bbox_label, bbox_results['bbox_pred'],
+ img_meta[0])
+
+ cls_score = sum(ms_scores) / float(len(ms_scores))
+ bboxes, scores = self.bbox_head[-1].get_bboxes(
+ rois,
+ cls_score,
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+
+ det_bbox_results = bbox2result(det_bboxes, det_labels,
+ self.bbox_head[-1].num_classes)
+
+ if self.with_mask:
+ if det_bboxes.shape[0] == 0:
+ det_segm_results = [[]
+ for _ in range(self.mask_head.num_classes)]
+ else:
+ aug_masks = []
+ for x, img_meta, semantic_feat, glbctx_feat in zip(
+ img_feats, img_metas, semantic_feats, glbctx_feats):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip)
+ mask_rois = bbox2roi([_bboxes])
+ # get relay feature on mask_rois
+ bbox_results = self._bbox_forward(
+ -1,
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat)
+ relayed_feat = bbox_results['relayed_feat']
+ relayed_feat = self.feat_relay_head(relayed_feat)
+ mask_results = self._mask_forward(
+ x,
+ mask_rois,
+ semantic_feat=semantic_feat,
+ glbctx_feat=glbctx_feat,
+ relayed_feat=relayed_feat)
+ mask_pred = mask_results['mask_pred']
+ aug_masks.append(mask_pred.sigmoid().cpu().numpy())
+ merged_masks = merge_aug_masks(aug_masks, img_metas,
+ self.test_cfg)
+ ori_shape = img_metas[0][0]['ori_shape']
+ det_segm_results = self.mask_head.get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ rcnn_test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return [(det_bbox_results, det_segm_results)]
+ else:
+ return [det_bbox_results]
diff --git a/mmdet/models/roi_heads/shared_heads/__init__.py b/mmdet/models/roi_heads/shared_heads/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbe70145b8bf7c304370f725f5afa8db98666679
--- /dev/null
+++ b/mmdet/models/roi_heads/shared_heads/__init__.py
@@ -0,0 +1,3 @@
+from .res_layer import ResLayer
+
+__all__ = ['ResLayer']
diff --git a/mmdet/models/roi_heads/shared_heads/res_layer.py b/mmdet/models/roi_heads/shared_heads/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5c343258b079a0dd832d4f999c18d002b06efac
--- /dev/null
+++ b/mmdet/models/roi_heads/shared_heads/res_layer.py
@@ -0,0 +1,77 @@
+import torch.nn as nn
+from mmcv.cnn import constant_init, kaiming_init
+from mmcv.runner import auto_fp16, load_checkpoint
+
+from mmdet.models.backbones import ResNet
+from mmdet.models.builder import SHARED_HEADS
+from mmdet.models.utils import ResLayer as _ResLayer
+from mmdet.utils import get_root_logger
+
+
+@SHARED_HEADS.register_module()
+class ResLayer(nn.Module):
+
+ def __init__(self,
+ depth,
+ stage=3,
+ stride=2,
+ dilation=1,
+ style='pytorch',
+ norm_cfg=dict(type='BN', requires_grad=True),
+ norm_eval=True,
+ with_cp=False,
+ dcn=None):
+ super(ResLayer, self).__init__()
+ self.norm_eval = norm_eval
+ self.norm_cfg = norm_cfg
+ self.stage = stage
+ self.fp16_enabled = False
+ block, stage_blocks = ResNet.arch_settings[depth]
+ stage_block = stage_blocks[stage]
+ planes = 64 * 2**stage
+ inplanes = 64 * 2**(stage - 1) * block.expansion
+
+ res_layer = _ResLayer(
+ block,
+ inplanes,
+ planes,
+ stage_block,
+ stride=stride,
+ dilation=dilation,
+ style=style,
+ with_cp=with_cp,
+ norm_cfg=self.norm_cfg,
+ dcn=dcn)
+ self.add_module(f'layer{stage + 1}', res_layer)
+
+ def init_weights(self, pretrained=None):
+ """Initialize the weights in the module.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if isinstance(pretrained, str):
+ logger = get_root_logger()
+ load_checkpoint(self, pretrained, strict=False, logger=logger)
+ elif pretrained is None:
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ kaiming_init(m)
+ elif isinstance(m, nn.BatchNorm2d):
+ constant_init(m, 1)
+ else:
+ raise TypeError('pretrained must be a str or None')
+
+ @auto_fp16()
+ def forward(self, x):
+ res_layer = getattr(self, f'layer{self.stage + 1}')
+ out = res_layer(x)
+ return out
+
+ def train(self, mode=True):
+ super(ResLayer, self).train(mode)
+ if self.norm_eval:
+ for m in self.modules():
+ if isinstance(m, nn.BatchNorm2d):
+ m.eval()
diff --git a/mmdet/models/roi_heads/sparse_roi_head.py b/mmdet/models/roi_heads/sparse_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d85ebc4698f3fc0b974e680c343f91deff4bb50
--- /dev/null
+++ b/mmdet/models/roi_heads/sparse_roi_head.py
@@ -0,0 +1,311 @@
+import torch
+
+from mmdet.core import bbox2result, bbox2roi, bbox_xyxy_to_cxcywh
+from mmdet.core.bbox.samplers import PseudoSampler
+from ..builder import HEADS
+from .cascade_roi_head import CascadeRoIHead
+
+
+@HEADS.register_module()
+class SparseRoIHead(CascadeRoIHead):
+ r"""The RoIHead for `Sparse R-CNN: End-to-End Object Detection with
+ Learnable Proposals `_
+
+ Args:
+        num_stages (int): Number of stages of the iterative refinement
+            process. Defaults to 6.
+ stage_loss_weights (Tuple[float]): The loss
+ weight of each stage. By default all stages have
+ the same weight 1.
+ bbox_roi_extractor (dict): Config of box roi extractor.
+ bbox_head (dict): Config of box head.
+ train_cfg (dict, optional): Configuration information in train stage.
+ Defaults to None.
+ test_cfg (dict, optional): Configuration information in test stage.
+ Defaults to None.
+
+ """
+
+ def __init__(self,
+ num_stages=6,
+ stage_loss_weights=(1, 1, 1, 1, 1, 1),
+ proposal_feature_channel=256,
+ bbox_roi_extractor=dict(
+ type='SingleRoIExtractor',
+ roi_layer=dict(
+ type='RoIAlign', output_size=7, sampling_ratio=2),
+ out_channels=256,
+ featmap_strides=[4, 8, 16, 32]),
+ bbox_head=dict(
+ type='DIIHead',
+ num_classes=80,
+ num_fcs=2,
+ num_heads=8,
+ num_cls_fcs=1,
+ num_reg_fcs=3,
+ feedforward_channels=2048,
+ hidden_channels=256,
+ dropout=0.0,
+ roi_feat_size=7,
+ ffn_act_cfg=dict(type='ReLU', inplace=True)),
+ train_cfg=None,
+ test_cfg=None):
+ assert bbox_roi_extractor is not None
+ assert bbox_head is not None
+ assert len(stage_loss_weights) == num_stages
+ self.num_stages = num_stages
+ self.stage_loss_weights = stage_loss_weights
+ self.proposal_feature_channel = proposal_feature_channel
+ super(SparseRoIHead, self).__init__(
+ num_stages,
+ stage_loss_weights,
+ bbox_roi_extractor=bbox_roi_extractor,
+ bbox_head=bbox_head,
+ train_cfg=train_cfg,
+ test_cfg=test_cfg)
+        # train_cfg would be None when running test.py
+ if train_cfg is not None:
+ for stage in range(num_stages):
+ assert isinstance(self.bbox_sampler[stage], PseudoSampler), \
+                    'Sparse R-CNN only supports `PseudoSampler`'
+
+ def _bbox_forward(self, stage, x, rois, object_feats, img_metas):
+ """Box head forward function used in both training and testing. Returns
+ all regression, classification results and a intermediate feature.
+
+ Args:
+ stage (int): The index of current stage in
+ iterative process.
+ x (List[Tensor]): List of FPN features
+ rois (Tensor): Rois in total batch. With shape (num_proposal, 5).
+ the last dimension 5 represents (img_index, x1, y1, x2, y2).
+ object_feats (Tensor): The object feature extracted from
+ the previous stage.
+            img_metas (list[dict]): meta information of images.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of bbox head outputs,
+                containing the following results:
+
+ - cls_score (Tensor): The score of each class, has
+ shape (batch_size, num_proposals, num_classes)
+ when use focal loss or
+ (batch_size, num_proposals, num_classes+1)
+ otherwise.
+ - decode_bbox_pred (Tensor): The regression results
+ with shape (batch_size, num_proposal, 4).
+ The last dimension 4 represents
+ [tl_x, tl_y, br_x, br_y].
+ - object_feats (Tensor): The object feature extracted
+ from current stage
+ - detach_cls_score_list (list[Tensor]): The detached
+ classification results, length is batch_size, and
+ each tensor has shape (num_proposal, num_classes).
+ - detach_proposal_list (list[tensor]): The detached
+ regression results, length is batch_size, and each
+ tensor has shape (num_proposal, 4). The last
+ dimension 4 represents [tl_x, tl_y, br_x, br_y].
+ """
+ num_imgs = len(img_metas)
+ bbox_roi_extractor = self.bbox_roi_extractor[stage]
+ bbox_head = self.bbox_head[stage]
+ bbox_feats = bbox_roi_extractor(x[:bbox_roi_extractor.num_inputs],
+ rois)
+ cls_score, bbox_pred, object_feats = bbox_head(bbox_feats,
+ object_feats)
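+        # Decode the predicted deltas into absolute boxes; the label and
+        # `pos_is_gts` arguments of `refine_bboxes` are placeholders here
+        # (see the "dummy arg" note below).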
+ proposal_list = self.bbox_head[stage].refine_bboxes(
+ rois,
+ rois.new_zeros(len(rois)), # dummy arg
+ bbox_pred.view(-1, bbox_pred.size(-1)),
+ [rois.new_zeros(object_feats.size(1)) for _ in range(num_imgs)],
+ img_metas)
+ bbox_results = dict(
+ cls_score=cls_score,
+ decode_bbox_pred=torch.cat(proposal_list),
+ object_feats=object_feats,
+ # detach then use it in label assign
+ detach_cls_score_list=[
+ cls_score[i].detach() for i in range(num_imgs)
+ ],
+ detach_proposal_list=[item.detach() for item in proposal_list])
+
+ return bbox_results
+
+ def forward_train(self,
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ imgs_whwh=None,
+ gt_masks=None):
+ """Forward function in training stage.
+
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+            proposal_boxes (Tensor): Decoded proposal bboxes, has shape
+ (batch_size, num_proposals, 4)
+ proposal_features (Tensor): Expanded proposal
+ features, has shape
+ (batch_size, num_proposals, proposal_feature_channel)
+ img_metas (list[dict]): list of image info dict where
+ each dict has: 'img_shape', 'scale_factor', 'flip',
+ and may also contain 'filename', 'ori_shape',
+ 'pad_shape', and 'img_norm_cfg'. For details on the
+ values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+ imgs_whwh (Tensor): Tensor with shape (batch_size, 4),
+ the dimension means
+                [img_width, img_height, img_width, img_height].
+            gt_masks (None | Tensor): true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components of all stage.
+ """
+
+ num_imgs = len(img_metas)
+ num_proposals = proposal_boxes.size(1)
+ imgs_whwh = imgs_whwh.repeat(1, num_proposals, 1)
+ all_stage_bbox_results = []
+ proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
+ object_feats = proposal_features
+ all_stage_loss = {}
+ for stage in range(self.num_stages):
+ rois = bbox2roi(proposal_list)
+ bbox_results = self._bbox_forward(stage, x, rois, object_feats,
+ img_metas)
+ all_stage_bbox_results.append(bbox_results)
+ if gt_bboxes_ignore is None:
+ # TODO support ignore
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ cls_pred_list = bbox_results['detach_cls_score_list']
+ proposal_list = bbox_results['detach_proposal_list']
+ for i in range(num_imgs):
+ normalize_bbox_ccwh = bbox_xyxy_to_cxcywh(proposal_list[i] /
+ imgs_whwh[i])
+ assign_result = self.bbox_assigner[stage].assign(
+ normalize_bbox_ccwh, cls_pred_list[i], gt_bboxes[i],
+ gt_labels[i], img_metas[i])
+ sampling_result = self.bbox_sampler[stage].sample(
+ assign_result, proposal_list[i], gt_bboxes[i])
+ sampling_results.append(sampling_result)
+ bbox_targets = self.bbox_head[stage].get_targets(
+ sampling_results, gt_bboxes, gt_labels, self.train_cfg[stage],
+ True)
+ cls_score = bbox_results['cls_score']
+ decode_bbox_pred = bbox_results['decode_bbox_pred']
+
+ single_stage_loss = self.bbox_head[stage].loss(
+ cls_score.view(-1, cls_score.size(-1)),
+ decode_bbox_pred.view(-1, 4),
+ *bbox_targets,
+ imgs_whwh=imgs_whwh)
+ for key, value in single_stage_loss.items():
+ all_stage_loss[f'stage{stage}_{key}'] = value * \
+ self.stage_loss_weights[stage]
+ object_feats = bbox_results['object_feats']
+
+ return all_stage_loss
+
+ def simple_test(self,
+ x,
+ proposal_boxes,
+ proposal_features,
+ img_metas,
+ imgs_whwh,
+ rescale=False):
+ """Test without augmentation.
+
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ proposal_boxes (Tensor): Decoded proposal bboxes, has shape
+ (batch_size, num_proposals, 4)
+ proposal_features (Tensor): Expanded proposal
+ features, has shape
+ (batch_size, num_proposals, proposal_feature_channel)
+            img_metas (list[dict]): meta information of images.
+ imgs_whwh (Tensor): Tensor with shape (batch_size, 4),
+ the dimension means
+                [img_width, img_height, img_width, img_height].
+ rescale (bool): If True, return boxes in original image
+ space. Defaults to False.
+
+ Returns:
+ bbox_results (list[tuple[np.ndarray]]): \
+ [[cls1_det, cls2_det, ...], ...]. \
+ The outer list indicates images, and the inner \
+ list indicates per-class detected bboxes. The \
+ np.ndarray has shape (num_det, 5) and the last \
+ dimension 5 represents (x1, y1, x2, y2, score).
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ # Decode initial proposals
+ num_imgs = len(img_metas)
+ proposal_list = [proposal_boxes[i] for i in range(num_imgs)]
+ object_feats = proposal_features
+ for stage in range(self.num_stages):
+ rois = bbox2roi(proposal_list)
+ bbox_results = self._bbox_forward(stage, x, rois, object_feats,
+ img_metas)
+ object_feats = bbox_results['object_feats']
+ cls_score = bbox_results['cls_score']
+ proposal_list = bbox_results['detach_proposal_list']
+
+ num_classes = self.bbox_head[-1].num_classes
+ det_bboxes = []
+ det_labels = []
+
+ if self.bbox_head[-1].loss_cls.use_sigmoid:
+ cls_score = cls_score.sigmoid()
+ else:
+ cls_score = cls_score.softmax(-1)[..., :-1]
+
+ for img_id in range(num_imgs):
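+            # Scores are flattened over (num_proposals, num_classes) before
+            # top-k, so the class label is recovered with a modulo and the
+            # proposal index with an integer division.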
+ cls_score_per_img = cls_score[img_id]
+ scores_per_img, topk_indices = cls_score_per_img.flatten(
+ 0, 1).topk(
+ self.test_cfg.max_per_img, sorted=False)
+ labels_per_img = topk_indices % num_classes
+ bbox_pred_per_img = proposal_list[img_id][topk_indices //
+ num_classes]
+ if rescale:
+ scale_factor = img_metas[img_id]['scale_factor']
+ bbox_pred_per_img /= bbox_pred_per_img.new_tensor(scale_factor)
+ det_bboxes.append(
+ torch.cat([bbox_pred_per_img, scores_per_img[:, None]], dim=1))
+ det_labels.append(labels_per_img)
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i], num_classes)
+ for i in range(num_imgs)
+ ]
+
+ return bbox_results
+
+ def aug_test(self, features, proposal_list, img_metas, rescale=False):
+ raise NotImplementedError('Sparse R-CNN does not support `aug_test`')
+
+ def forward_dummy(self, x, proposal_boxes, proposal_features, img_metas):
+ """Dummy forward function when do the flops computing."""
+ all_stage_bbox_results = []
+ proposal_list = [proposal_boxes[i] for i in range(len(proposal_boxes))]
+ object_feats = proposal_features
+ if self.with_bbox:
+ for stage in range(self.num_stages):
+ rois = bbox2roi(proposal_list)
+ bbox_results = self._bbox_forward(stage, x, rois, object_feats,
+ img_metas)
+
+ all_stage_bbox_results.append(bbox_results)
+ proposal_list = bbox_results['detach_proposal_list']
+ object_feats = bbox_results['object_feats']
+ return all_stage_bbox_results
diff --git a/mmdet/models/roi_heads/standard_roi_head.py b/mmdet/models/roi_heads/standard_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d5e163e90b4e2bba6ee1b04a7d8989a52e07fa3
--- /dev/null
+++ b/mmdet/models/roi_heads/standard_roi_head.py
@@ -0,0 +1,306 @@
+import torch
+
+from mmdet.core import bbox2result, bbox2roi, build_assigner, build_sampler
+from ..builder import HEADS, build_head, build_roi_extractor
+from .base_roi_head import BaseRoIHead
+from .test_mixins import BBoxTestMixin, MaskTestMixin
+
+
+@HEADS.register_module()
+class StandardRoIHead(BaseRoIHead, BBoxTestMixin, MaskTestMixin):
+ """Simplest base roi head including one bbox head and one mask head."""
+
+ def init_assigner_sampler(self):
+ """Initialize assigner and sampler."""
+ self.bbox_assigner = None
+ self.bbox_sampler = None
+ if self.train_cfg:
+ self.bbox_assigner = build_assigner(self.train_cfg.assigner)
+ self.bbox_sampler = build_sampler(
+ self.train_cfg.sampler, context=self)
+
+ def init_bbox_head(self, bbox_roi_extractor, bbox_head):
+ """Initialize ``bbox_head``"""
+ self.bbox_roi_extractor = build_roi_extractor(bbox_roi_extractor)
+ self.bbox_head = build_head(bbox_head)
+
+ def init_mask_head(self, mask_roi_extractor, mask_head):
+ """Initialize ``mask_head``"""
+ if mask_roi_extractor is not None:
+ self.mask_roi_extractor = build_roi_extractor(mask_roi_extractor)
+ self.share_roi_extractor = False
+ else:
+ self.share_roi_extractor = True
+ self.mask_roi_extractor = self.bbox_roi_extractor
+ self.mask_head = build_head(mask_head)
+
+ def init_gan_head(self, gan_roi_extractor, gan_head):
+ """Initialize ``mask_head``"""
+ if gan_roi_extractor is not None:
+ self.gan_roi_extractor = build_roi_extractor(gan_roi_extractor)
+ self.share_roi_extractor = False
+ else:
+ self.share_roi_extractor = True
+ self.gan_roi_extractor = self.bbox_roi_extractor
+ self.gan_head = build_head(gan_head)
+
+ def init_weights(self, pretrained):
+ """Initialize the weights in head.
+
+ Args:
+ pretrained (str, optional): Path to pre-trained weights.
+ Defaults to None.
+ """
+ if self.with_shared_head:
+ self.shared_head.init_weights(pretrained=pretrained)
+ if self.with_bbox:
+ self.bbox_roi_extractor.init_weights()
+ self.bbox_head.init_weights()
+ if self.with_mask:
+ self.mask_head.init_weights()
+ if not self.share_roi_extractor:
+ self.mask_roi_extractor.init_weights()
+
+ def forward_dummy(self, x, proposals):
+ """Dummy forward function."""
+ # bbox head
+ outs = ()
+ rois = bbox2roi([proposals])
+ if self.with_bbox:
+ bbox_results = self._bbox_forward(x, rois)
+ outs = outs + (bbox_results['cls_score'],
+ bbox_results['bbox_pred'])
+ # mask head
+ if self.with_mask:
+ mask_rois = rois[:100]
+ mask_results = self._mask_forward(x, mask_rois)
+ outs = outs + (mask_results['mask_pred'], )
+ return outs
+
+ def forward_train(self,
+ x,
+ img_metas,
+ proposal_list,
+ gt_bboxes,
+ gt_labels,
+ gt_bboxes_ignore=None,
+ gt_masks=None):
+ """
+ Args:
+ x (list[Tensor]): list of multi-level img features.
+ img_metas (list[dict]): list of image info dict where each dict
+ has: 'img_shape', 'scale_factor', 'flip', and may also contain
+ 'filename', 'ori_shape', 'pad_shape', and 'img_norm_cfg'.
+ For details on the values of these keys see
+ `mmdet/datasets/pipelines/formatting.py:Collect`.
+            proposal_list (list[Tensor]): list of region proposals.
+ gt_bboxes (list[Tensor]): Ground truth bboxes for each image with
+ shape (num_gts, 4) in [tl_x, tl_y, br_x, br_y] format.
+ gt_labels (list[Tensor]): class indices corresponding to each box
+ gt_bboxes_ignore (None | list[Tensor]): specify which bounding
+ boxes can be ignored when computing the loss.
+            gt_masks (None | Tensor): true segmentation masks for each box
+ used if the architecture supports a segmentation task.
+
+ Returns:
+ dict[str, Tensor]: a dictionary of loss components
+ """
+ # assign gts and sample proposals
+ if self.with_bbox or self.with_mask:
+ num_imgs = len(img_metas)
+ if gt_bboxes_ignore is None:
+ gt_bboxes_ignore = [None for _ in range(num_imgs)]
+ sampling_results = []
+ for i in range(num_imgs):
+ assign_result = self.bbox_assigner.assign(
+ proposal_list[i], gt_bboxes[i], gt_bboxes_ignore[i],
+ gt_labels[i])
+ sampling_result = self.bbox_sampler.sample(
+ assign_result,
+ proposal_list[i],
+ gt_bboxes[i],
+ gt_labels[i],
+ feats=[lvl_feat[i][None] for lvl_feat in x])
+ sampling_results.append(sampling_result)
+
+ losses = dict()
+ # bbox head forward and loss
+ if self.with_bbox:
+ bbox_results = self._bbox_forward_train(x, sampling_results,
+ gt_bboxes, gt_labels,
+ img_metas)
+ losses.update(bbox_results['loss_bbox'])
+
+ # mask head forward and loss
+ if self.with_mask:
+ mask_results = self._mask_forward_train(x, sampling_results,
+ bbox_results['bbox_feats'],
+ gt_masks, img_metas)
+ losses.update(mask_results['loss_mask'])
+
+ return losses
+
+ def _bbox_forward(self, x, rois):
+ """Box head forward function used in both training and testing."""
+ # TODO: a more flexible way to decide which feature maps to use
+ bbox_feats = self.bbox_roi_extractor(
+ x[:self.bbox_roi_extractor.num_inputs], rois)
+ if self.with_shared_head:
+ bbox_feats = self.shared_head(bbox_feats)
+ cls_score, bbox_pred = self.bbox_head(bbox_feats)
+
+ bbox_results = dict(
+ cls_score=cls_score, bbox_pred=bbox_pred, bbox_feats=bbox_feats)
+ return bbox_results
+
+ def _bbox_forward_train(self, x, sampling_results, gt_bboxes, gt_labels,
+ img_metas):
+ """Run forward function and calculate loss for box head in training."""
+ rois = bbox2roi([res.bboxes for res in sampling_results])
+ bbox_results = self._bbox_forward(x, rois)
+
+ bbox_targets = self.bbox_head.get_targets(sampling_results, gt_bboxes,
+ gt_labels, self.train_cfg)
+ loss_bbox = self.bbox_head.loss(bbox_results['cls_score'],
+ bbox_results['bbox_pred'], rois,
+ *bbox_targets)
+
+ bbox_results.update(loss_bbox=loss_bbox)
+ return bbox_results
+
+ def _mask_forward_train(self, x, sampling_results, bbox_feats, gt_masks,
+ img_metas):
+ """Run forward function and calculate loss for mask head in
+ training."""
+ if not self.share_roi_extractor:
+ pos_rois = bbox2roi([res.pos_bboxes for res in sampling_results])
+ mask_results = self._mask_forward(x, pos_rois)
+ else:
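+            # The mask branch shares the bbox RoI extractor, so reuse the
+            # already computed `bbox_feats` and keep only the features of
+            # positive samples via a per-image boolean index.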
+ pos_inds = []
+ device = bbox_feats.device
+ for res in sampling_results:
+ pos_inds.append(
+ torch.ones(
+ res.pos_bboxes.shape[0],
+ device=device,
+ dtype=torch.uint8))
+ pos_inds.append(
+ torch.zeros(
+ res.neg_bboxes.shape[0],
+ device=device,
+ dtype=torch.uint8))
+ pos_inds = torch.cat(pos_inds)
+
+ mask_results = self._mask_forward(
+ x, pos_inds=pos_inds, bbox_feats=bbox_feats)
+
+ mask_targets = self.mask_head.get_targets(sampling_results, gt_masks,
+ self.train_cfg)
+ pos_labels = torch.cat([res.pos_gt_labels for res in sampling_results])
+ loss_mask = self.mask_head.loss(mask_results['mask_pred'],
+ mask_targets, pos_labels)
+
+ mask_results.update(loss_mask=loss_mask, mask_targets=mask_targets)
+ return mask_results
+
+ def _mask_forward(self, x, rois=None, pos_inds=None, bbox_feats=None):
+ """Mask head forward function used in both training and testing."""
+ assert ((rois is not None) ^
+ (pos_inds is not None and bbox_feats is not None))
+ if rois is not None:
+ mask_feats = self.mask_roi_extractor(
+ x[:self.mask_roi_extractor.num_inputs], rois)
+ if self.with_shared_head:
+ mask_feats = self.shared_head(mask_feats)
+ else:
+ assert bbox_feats is not None
+ mask_feats = bbox_feats[pos_inds]
+
+ mask_pred = self.mask_head(mask_feats)
+ mask_results = dict(mask_pred=mask_pred, mask_feats=mask_feats)
+ return mask_results
+
+ async def async_simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Async test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ det_bboxes, det_labels = await self.async_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+ bbox_results = bbox2result(det_bboxes, det_labels,
+ self.bbox_head.num_classes)
+ if not self.with_mask:
+ return bbox_results
+ else:
+ segm_results = await self.async_test_mask(
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=rescale,
+ mask_test_cfg=self.test_cfg.get('mask'))
+ return bbox_results, segm_results
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Test without augmentation."""
+ assert self.with_bbox, 'Bbox head must be implemented.'
+
+ det_bboxes, det_labels = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+ if torch.onnx.is_in_onnx_export():
+ if self.with_mask:
+ segm_results = self.simple_test_mask(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ return det_bboxes, det_labels, segm_results
+ else:
+ return det_bboxes, det_labels
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head.num_classes)
+ for i in range(len(det_bboxes))
+ ]
+
+ if not self.with_mask:
+ return bbox_results
+ else:
+ segm_results = self.simple_test_mask(
+ x, img_metas, det_bboxes, det_labels, rescale=rescale)
+ return list(zip(bbox_results, segm_results))
+
+ def aug_test(self, x, proposal_list, img_metas, rescale=False):
+ """Test with augmentations.
+
+ If rescale is False, then returned bboxes and masks will fit the scale
+ of imgs[0].
+ """
+ det_bboxes, det_labels = self.aug_test_bboxes(x, img_metas,
+ proposal_list,
+ self.test_cfg)
+
+ if rescale:
+ _det_bboxes = det_bboxes
+ else:
+ _det_bboxes = det_bboxes.clone()
+ _det_bboxes[:, :4] *= det_bboxes.new_tensor(
+ img_metas[0][0]['scale_factor'])
+ bbox_results = bbox2result(_det_bboxes, det_labels,
+ self.bbox_head.num_classes)
+
+ # det_bboxes always keep the original scale
+ if self.with_mask:
+ segm_results = self.aug_test_mask(x, img_metas, det_bboxes,
+ det_labels)
+ return [(bbox_results, segm_results)]
+ else:
+ return [bbox_results]
diff --git a/mmdet/models/roi_heads/test_mixins.py b/mmdet/models/roi_heads/test_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..c28ed61deb946f0ffca70733fb7ddf84d1aec885
--- /dev/null
+++ b/mmdet/models/roi_heads/test_mixins.py
@@ -0,0 +1,368 @@
+import logging
+import sys
+
+import torch
+
+from mmdet.core import (bbox2roi, bbox_mapping, merge_aug_bboxes,
+ merge_aug_masks, multiclass_nms)
+
+logger = logging.getLogger(__name__)
+
+if sys.version_info >= (3, 7):
+ from mmdet.utils.contextmanagers import completed
+
+
+class BBoxTestMixin(object):
+
+ if sys.version_info >= (3, 7):
+
+ async def async_test_bboxes(self,
+ x,
+ img_metas,
+ proposals,
+ rcnn_test_cfg,
+ rescale=False,
+ bbox_semaphore=None,
+ global_lock=None):
+ """Asynchronized test for box head without augmentation."""
+ rois = bbox2roi(proposals)
+ roi_feats = self.bbox_roi_extractor(
+ x[:len(self.bbox_roi_extractor.featmap_strides)], rois)
+ if self.with_shared_head:
+ roi_feats = self.shared_head(roi_feats)
+ sleep_interval = rcnn_test_cfg.get('async_sleep_interval', 0.017)
+
+ async with completed(
+ __name__, 'bbox_head_forward',
+ sleep_interval=sleep_interval):
+ cls_score, bbox_pred = self.bbox_head(roi_feats)
+
+ img_shape = img_metas[0]['img_shape']
+ scale_factor = img_metas[0]['scale_factor']
+ det_bboxes, det_labels = self.bbox_head.get_bboxes(
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shape,
+ scale_factor,
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ return det_bboxes, det_labels
+
+ def simple_test_bboxes(self,
+ x,
+ img_metas,
+ proposals,
+ rcnn_test_cfg,
+ rescale=False):
+ """Test only det bboxes without augmentation.
+
+ Args:
+ x (tuple[Tensor]): Feature maps of all scale level.
+ img_metas (list[dict]): Image meta info.
+ proposals (Tensor or List[Tensor]): Region proposals.
+ rcnn_test_cfg (obj:`ConfigDict`): `test_cfg` of R-CNN.
+ rescale (bool): If True, return boxes in original image space.
+ Default: False.
+
+ Returns:
+ tuple[list[Tensor], list[Tensor]]: The first list contains
+ the boxes of the corresponding image in a batch, each
+ tensor has the shape (num_boxes, 5) and last dimension
+ 5 represent (tl_x, tl_y, br_x, br_y, score). Each Tensor
+ in the second list is the labels with shape (num_boxes, ).
+ The length of both lists should be equal to batch_size.
+ """
+ # get origin input shape to support onnx dynamic input shape
+ if torch.onnx.is_in_onnx_export():
+ assert len(
+ img_metas
+            ) == 1, 'Only support one input image while exporting to ONNX'
+ img_shapes = img_metas[0]['img_shape_for_onnx']
+ else:
+ img_shapes = tuple(meta['img_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # The length of proposals of different batches may be different.
+ # In order to form a batch, a padding operation is required.
+ if isinstance(proposals, list):
+ # padding to form a batch
+ max_size = max([proposal.size(0) for proposal in proposals])
+ for i, proposal in enumerate(proposals):
+ supplement = proposal.new_full(
+ (max_size - proposal.size(0), proposal.size(1)), 0)
+ proposals[i] = torch.cat((supplement, proposal), dim=0)
+ rois = torch.stack(proposals, dim=0)
+ else:
+ rois = proposals
+
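+        # Prepend the image index to every box so that the flattened
+        # proposals follow the (batch_ind, x1, y1, x2, y2) RoI convention
+        # expected by the RoI extractor.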
+ batch_index = torch.arange(
+ rois.size(0), device=rois.device).float().view(-1, 1, 1).expand(
+ rois.size(0), rois.size(1), 1)
+ rois = torch.cat([batch_index, rois[..., :4]], dim=-1)
+ batch_size = rois.shape[0]
+ num_proposals_per_img = rois.shape[1]
+
+ # Eliminate the batch dimension
+ rois = rois.view(-1, 5)
+ bbox_results = self._bbox_forward(x, rois)
+ cls_score = bbox_results['cls_score']
+ bbox_pred = bbox_results['bbox_pred']
+
+ # Recover the batch dimension
+ rois = rois.reshape(batch_size, num_proposals_per_img, -1)
+ cls_score = cls_score.reshape(batch_size, num_proposals_per_img, -1)
+
+ if not torch.onnx.is_in_onnx_export():
+ # remove padding
+ supplement_mask = rois[..., -1] == 0
+ cls_score[supplement_mask, :] = 0
+
+ # bbox_pred would be None in some detector when with_reg is False,
+ # e.g. Grid R-CNN.
+ if bbox_pred is not None:
+ # the bbox prediction of some detectors like SABL is not Tensor
+ if isinstance(bbox_pred, torch.Tensor):
+ bbox_pred = bbox_pred.reshape(batch_size,
+ num_proposals_per_img, -1)
+ if not torch.onnx.is_in_onnx_export():
+ bbox_pred[supplement_mask, :] = 0
+ else:
+ # TODO: Looking forward to a better way
+ # For SABL
+ bbox_preds = self.bbox_head.bbox_pred_split(
+ bbox_pred, num_proposals_per_img)
+ # apply bbox post-processing to each image individually
+ det_bboxes = []
+ det_labels = []
+ for i in range(len(proposals)):
+ # remove padding
+ supplement_mask = proposals[i][..., -1] == 0
+ for bbox in bbox_preds[i]:
+ bbox[supplement_mask] = 0
+ det_bbox, det_label = self.bbox_head.get_bboxes(
+ rois[i],
+ cls_score[i],
+ bbox_preds[i],
+ img_shapes[i],
+ scale_factors[i],
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+ det_bboxes.append(det_bbox)
+ det_labels.append(det_label)
+ return det_bboxes, det_labels
+ else:
+ bbox_pred = None
+
+ return self.bbox_head.get_bboxes(
+ rois,
+ cls_score,
+ bbox_pred,
+ img_shapes,
+ scale_factors,
+ rescale=rescale,
+ cfg=rcnn_test_cfg)
+
+ def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
+ """Test det bboxes with test time augmentation."""
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta in zip(feats, img_metas):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ # TODO more flexible
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ rois = bbox2roi([proposals])
+ bbox_results = self._bbox_forward(x, rois)
+ bboxes, scores = self.bbox_head.get_bboxes(
+ rois,
+ bbox_results['cls_score'],
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ aug_bboxes.append(bboxes)
+ aug_scores.append(scores)
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+ return det_bboxes, det_labels
+
+
+class MaskTestMixin(object):
+
+ if sys.version_info >= (3, 7):
+
+ async def async_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False,
+ mask_test_cfg=None):
+ """Asynchronized test for mask head without augmentation."""
+ # image shape of the first image in the batch (only one)
+ ori_shape = img_metas[0]['ori_shape']
+ scale_factor = img_metas[0]['scale_factor']
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[] for _ in range(self.mask_head.num_classes)]
+ else:
+ if rescale and not isinstance(scale_factor,
+ (float, torch.Tensor)):
+ scale_factor = det_bboxes.new_tensor(scale_factor)
+ _bboxes = (
+ det_bboxes[:, :4] *
+ scale_factor if rescale else det_bboxes)
+ mask_rois = bbox2roi([_bboxes])
+ mask_feats = self.mask_roi_extractor(
+ x[:len(self.mask_roi_extractor.featmap_strides)],
+ mask_rois)
+
+ if self.with_shared_head:
+ mask_feats = self.shared_head(mask_feats)
+ if mask_test_cfg and mask_test_cfg.get('async_sleep_interval'):
+ sleep_interval = mask_test_cfg['async_sleep_interval']
+ else:
+ sleep_interval = 0.035
+ async with completed(
+ __name__,
+ 'mask_head_forward',
+ sleep_interval=sleep_interval):
+ mask_pred = self.mask_head(mask_feats)
+ segm_result = self.mask_head.get_seg_masks(
+ mask_pred, _bboxes, det_labels, self.test_cfg, ori_shape,
+ scale_factor, rescale)
+ return segm_result
+
+ def simple_test_mask(self,
+ x,
+ img_metas,
+ det_bboxes,
+ det_labels,
+ rescale=False):
+ """Simple test for mask head without augmentation."""
+ # image shapes of images in the batch
+ ori_shapes = tuple(meta['ori_shape'] for meta in img_metas)
+ scale_factors = tuple(meta['scale_factor'] for meta in img_metas)
+
+ # The length of proposals of different batches may be different.
+ # In order to form a batch, a padding operation is required.
+ if isinstance(det_bboxes, list):
+ # padding to form a batch
+ max_size = max([bboxes.size(0) for bboxes in det_bboxes])
+ for i, (bbox, label) in enumerate(zip(det_bboxes, det_labels)):
+ supplement_bbox = bbox.new_full(
+ (max_size - bbox.size(0), bbox.size(1)), 0)
+ supplement_label = label.new_full((max_size - label.size(0), ),
+ 0)
+ det_bboxes[i] = torch.cat((supplement_bbox, bbox), dim=0)
+ det_labels[i] = torch.cat((supplement_label, label), dim=0)
+ det_bboxes = torch.stack(det_bboxes, dim=0)
+ det_labels = torch.stack(det_labels, dim=0)
+
+ batch_size = det_bboxes.size(0)
+ num_proposals_per_img = det_bboxes.shape[1]
+
+ # if det_bboxes is rescaled to the original image size, we need to
+ # rescale it back to the testing scale to obtain RoIs.
+ det_bboxes = det_bboxes[..., :4]
+ if rescale:
+ if not isinstance(scale_factors[0], float):
+ scale_factors = det_bboxes.new_tensor(scale_factors)
+ det_bboxes = det_bboxes * scale_factors.unsqueeze(1)
+
+ batch_index = torch.arange(
+ det_bboxes.size(0), device=det_bboxes.device).float().view(
+ -1, 1, 1).expand(det_bboxes.size(0), det_bboxes.size(1), 1)
+ mask_rois = torch.cat([batch_index, det_bboxes], dim=-1)
+ mask_rois = mask_rois.view(-1, 5)
+ mask_results = self._mask_forward(x, mask_rois)
+ mask_pred = mask_results['mask_pred']
+        # The WALT mask head may return a (full, occluded) pair of
+        # predictions; fall back to using the single prediction for both.
+        try:
+            mask_full_pred, mask_occ_pred = mask_pred
+        except (TypeError, ValueError):
+            mask_full_pred = mask_pred
+            mask_occ_pred = mask_pred
+
+        # Recover the batch dimension
+        mask_full_preds = mask_full_pred.reshape(
+            batch_size, num_proposals_per_img, *mask_full_pred.shape[1:])
+        mask_occ_preds = mask_occ_pred.reshape(
+            batch_size, num_proposals_per_img, *mask_occ_pred.shape[1:])
+
+ # apply mask post-processing to each image individually
+ segm_results = []
+ for i in range(batch_size):
+ mask_full_pred = mask_full_preds[i]
+ mask_occ_pred = mask_occ_preds[i]
+ det_bbox = det_bboxes[i]
+ det_label = det_labels[i]
+
+ # remove padding
+ supplement_mask = det_bbox[..., -1] != 0
+ mask_full_pred = mask_full_pred[supplement_mask]
+ mask_occ_pred = mask_occ_pred[supplement_mask]
+ det_bbox = det_bbox[supplement_mask]
+ det_label = det_label[supplement_mask]
+
+ if det_label.shape[0] == 0:
+ segm_results.append([[]
+ for _ in range(self.mask_head.num_classes)
+ ])
+ else:
+ segm_result_vis = self.mask_head.get_seg_masks(
+                    mask_full_pred[:, 0:1], det_bbox, det_label, self.test_cfg,
+ ori_shapes[i], scale_factors[i], rescale)
+
+ segm_result_occ = self.mask_head.get_seg_masks(
+                    mask_occ_pred[:, 0:1], det_bbox, det_label, self.test_cfg,
+ ori_shapes[i], scale_factors[i], rescale)
+
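+                # Pack both outputs into one result list: index 0 keeps the
+                # masks from the full-mask prediction and index 1 is
+                # overwritten with the masks from the occlusion prediction.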
+ segm_result = segm_result_vis
+ segm_result[1] = segm_result_occ[0]
+
+ segm_results.append(segm_result)
+ return segm_results
+
+ def aug_test_mask(self, feats, img_metas, det_bboxes, det_labels):
+ """Test for mask head with test time augmentation."""
+ if det_bboxes.shape[0] == 0:
+ segm_result = [[] for _ in range(self.mask_head.num_classes)]
+ else:
+ aug_masks = []
+ for x, img_meta in zip(feats, img_metas):
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+ _bboxes = bbox_mapping(det_bboxes[:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ mask_rois = bbox2roi([_bboxes])
+ mask_results = self._mask_forward(x, mask_rois)
+ # convert to numpy array to save memory
+ aug_masks.append(
+ mask_results['mask_pred'].sigmoid().cpu().numpy())
+ merged_masks = merge_aug_masks(aug_masks, img_metas, self.test_cfg)
+
+ ori_shape = img_metas[0][0]['ori_shape']
+ segm_result = self.mask_head.get_seg_masks(
+ merged_masks,
+ det_bboxes,
+ det_labels,
+ self.test_cfg,
+ ori_shape,
+ scale_factor=1.0,
+ rescale=False)
+ return segm_result
diff --git a/mmdet/models/roi_heads/trident_roi_head.py b/mmdet/models/roi_heads/trident_roi_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..245569e50b45cc8e21ba8e7210edf4bd0c7f27c5
--- /dev/null
+++ b/mmdet/models/roi_heads/trident_roi_head.py
@@ -0,0 +1,119 @@
+import torch
+from mmcv.ops import batched_nms
+
+from mmdet.core import (bbox2result, bbox2roi, bbox_mapping, merge_aug_bboxes,
+ multiclass_nms)
+from mmdet.models.roi_heads.standard_roi_head import StandardRoIHead
+from ..builder import HEADS
+
+
+@HEADS.register_module()
+class TridentRoIHead(StandardRoIHead):
+ """Trident roi head.
+
+ Args:
+ num_branch (int): Number of branches in TridentNet.
+ test_branch_idx (int): In inference, all 3 branches will be used
+ if `test_branch_idx==-1`, otherwise only branch with index
+ `test_branch_idx` will be used.
+ """
+
+ def __init__(self, num_branch, test_branch_idx, **kwargs):
+ self.num_branch = num_branch
+ self.test_branch_idx = test_branch_idx
+ super(TridentRoIHead, self).__init__(**kwargs)
+
+ def merge_trident_bboxes(self, trident_det_bboxes, trident_det_labels):
+ """Merge bbox predictions of each branch."""
+ if trident_det_bboxes.numel() == 0:
+ det_bboxes = trident_det_bboxes.new_zeros((0, 5))
+ det_labels = trident_det_bboxes.new_zeros((0, ), dtype=torch.long)
+ else:
+ nms_bboxes = trident_det_bboxes[:, :4]
+ nms_scores = trident_det_bboxes[:, 4].contiguous()
+ nms_inds = trident_det_labels
+ nms_cfg = self.test_cfg['nms']
+ det_bboxes, keep = batched_nms(nms_bboxes, nms_scores, nms_inds,
+ nms_cfg)
+ det_labels = trident_det_labels[keep]
+ if self.test_cfg['max_per_img'] > 0:
+ det_labels = det_labels[:self.test_cfg['max_per_img']]
+ det_bboxes = det_bboxes[:self.test_cfg['max_per_img']]
+
+ return det_bboxes, det_labels
+
+ def simple_test(self,
+ x,
+ proposal_list,
+ img_metas,
+ proposals=None,
+ rescale=False):
+ """Test without augmentation as follows:
+
+ 1. Compute prediction bbox and label per branch.
+ 2. Merge predictions of each branch according to scores of
+ bboxes, i.e., bboxes with higher score are kept to give
+ top-k prediction.
+ """
+ assert self.with_bbox, 'Bbox head must be implemented.'
+ det_bboxes_list, det_labels_list = self.simple_test_bboxes(
+ x, img_metas, proposal_list, self.test_cfg, rescale=rescale)
+ num_branch = self.num_branch if self.test_branch_idx == -1 else 1
+ for _ in range(len(det_bboxes_list)):
+ if det_bboxes_list[_].shape[0] == 0:
+ det_bboxes_list[_] = det_bboxes_list[_].new_empty((0, 5))
+ det_bboxes, det_labels = [], []
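+        # Each original image contributes `num_branch` consecutive entries to
+        # the flattened batch; concatenate those branch predictions and merge
+        # them per image with NMS.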
+ for i in range(len(img_metas) // num_branch):
+ det_result = self.merge_trident_bboxes(
+ torch.cat(det_bboxes_list[i * num_branch:(i + 1) *
+ num_branch]),
+ torch.cat(det_labels_list[i * num_branch:(i + 1) *
+ num_branch]))
+ det_bboxes.append(det_result[0])
+ det_labels.append(det_result[1])
+
+ bbox_results = [
+ bbox2result(det_bboxes[i], det_labels[i],
+ self.bbox_head.num_classes)
+ for i in range(len(det_bboxes))
+ ]
+ return bbox_results
+
+ def aug_test_bboxes(self, feats, img_metas, proposal_list, rcnn_test_cfg):
+ """Test det bboxes with test time augmentation."""
+ aug_bboxes = []
+ aug_scores = []
+ for x, img_meta in zip(feats, img_metas):
+ # only one image in the batch
+ img_shape = img_meta[0]['img_shape']
+ scale_factor = img_meta[0]['scale_factor']
+ flip = img_meta[0]['flip']
+ flip_direction = img_meta[0]['flip_direction']
+
+ trident_bboxes, trident_scores = [], []
+ for branch_idx in range(len(proposal_list)):
+ proposals = bbox_mapping(proposal_list[0][:, :4], img_shape,
+ scale_factor, flip, flip_direction)
+ rois = bbox2roi([proposals])
+ bbox_results = self._bbox_forward(x, rois)
+ bboxes, scores = self.bbox_head.get_bboxes(
+ rois,
+ bbox_results['cls_score'],
+ bbox_results['bbox_pred'],
+ img_shape,
+ scale_factor,
+ rescale=False,
+ cfg=None)
+ trident_bboxes.append(bboxes)
+ trident_scores.append(scores)
+
+ aug_bboxes.append(torch.cat(trident_bboxes, 0))
+ aug_scores.append(torch.cat(trident_scores, 0))
+ # after merging, bboxes will be rescaled to the original image size
+ merged_bboxes, merged_scores = merge_aug_bboxes(
+ aug_bboxes, aug_scores, img_metas, rcnn_test_cfg)
+ det_bboxes, det_labels = multiclass_nms(merged_bboxes, merged_scores,
+ rcnn_test_cfg.score_thr,
+ rcnn_test_cfg.nms,
+ rcnn_test_cfg.max_per_img)
+ return det_bboxes, det_labels
diff --git a/mmdet/models/utils/__init__.py b/mmdet/models/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5165b22ce57d17f28392213e0f1b055c2b9360c1
--- /dev/null
+++ b/mmdet/models/utils/__init__.py
@@ -0,0 +1,16 @@
+from .builder import build_positional_encoding, build_transformer
+from .gaussian_target import gaussian_radius, gen_gaussian_target
+from .positional_encoding import (LearnedPositionalEncoding,
+ SinePositionalEncoding)
+from .res_layer import ResLayer, SimplifiedBasicBlock
+from .transformer import (FFN, DynamicConv, MultiheadAttention, Transformer,
+ TransformerDecoder, TransformerDecoderLayer,
+ TransformerEncoder, TransformerEncoderLayer)
+
+__all__ = [
+ 'ResLayer', 'gaussian_radius', 'gen_gaussian_target', 'MultiheadAttention',
+ 'FFN', 'TransformerEncoderLayer', 'TransformerEncoder',
+ 'TransformerDecoderLayer', 'TransformerDecoder', 'Transformer',
+ 'build_transformer', 'build_positional_encoding', 'SinePositionalEncoding',
+ 'LearnedPositionalEncoding', 'DynamicConv', 'SimplifiedBasicBlock'
+]
diff --git a/mmdet/models/utils/builder.py b/mmdet/models/utils/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..f362d1c92ca9d4ed95a2b3d28d3e6baedd14e462
--- /dev/null
+++ b/mmdet/models/utils/builder.py
@@ -0,0 +1,14 @@
+from mmcv.utils import Registry, build_from_cfg
+
+TRANSFORMER = Registry('Transformer')
+POSITIONAL_ENCODING = Registry('Position encoding')
+
+
+def build_transformer(cfg, default_args=None):
+ """Builder for Transformer."""
+ return build_from_cfg(cfg, TRANSFORMER, default_args)
+
+
+def build_positional_encoding(cfg, default_args=None):
+ """Builder for Position Encoding."""
+ return build_from_cfg(cfg, POSITIONAL_ENCODING, default_args)
diff --git a/mmdet/models/utils/gaussian_target.py b/mmdet/models/utils/gaussian_target.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bb7160cb4bf2f47876f6e8373142aa5846920a9
--- /dev/null
+++ b/mmdet/models/utils/gaussian_target.py
@@ -0,0 +1,185 @@
+from math import sqrt
+
+import torch
+
+
+def gaussian2D(radius, sigma=1, dtype=torch.float32, device='cpu'):
+ """Generate 2D gaussian kernel.
+
+ Args:
+ radius (int): Radius of gaussian kernel.
+ sigma (int): Sigma of gaussian function. Default: 1.
+ dtype (torch.dtype): Dtype of gaussian tensor. Default: torch.float32.
+ device (str): Device of gaussian tensor. Default: 'cpu'.
+
+ Returns:
+ h (Tensor): Gaussian kernel with a
+ ``(2 * radius + 1) * (2 * radius + 1)`` shape.
+ """
+ x = torch.arange(
+ -radius, radius + 1, dtype=dtype, device=device).view(1, -1)
+ y = torch.arange(
+ -radius, radius + 1, dtype=dtype, device=device).view(-1, 1)
+
+ h = (-(x * x + y * y) / (2 * sigma * sigma)).exp()
+
+ h[h < torch.finfo(h.dtype).eps * h.max()] = 0
+ return h
+
+
+def gen_gaussian_target(heatmap, center, radius, k=1):
+ """Generate 2D gaussian heatmap.
+
+ Args:
+ heatmap (Tensor): Input heatmap; the gaussian kernel is painted
+ onto it, keeping the element-wise maximum.
+ center (list[int]): Coord of gaussian kernel's center.
+ radius (int): Radius of gaussian kernel.
+ k (int): Coefficient of gaussian kernel. Default: 1.
+
+ Returns:
+ out_heatmap (Tensor): Updated heatmap covered by gaussian kernel.
+ """
+ diameter = 2 * radius + 1
+ gaussian_kernel = gaussian2D(
+ radius, sigma=diameter / 6, dtype=heatmap.dtype, device=heatmap.device)
+
+ x, y = center
+
+ height, width = heatmap.shape[:2]
+
+ left, right = min(x, radius), min(width - x, radius + 1)
+ top, bottom = min(y, radius), min(height - y, radius + 1)
+
+ masked_heatmap = heatmap[y - top:y + bottom, x - left:x + right]
+ masked_gaussian = gaussian_kernel[radius - top:radius + bottom,
+ radius - left:radius + right]
+ out_heatmap = heatmap
+ torch.max(
+ masked_heatmap,
+ masked_gaussian * k,
+ out=out_heatmap[y - top:y + bottom, x - left:x + right])
+
+ return out_heatmap
+
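+# A minimal usage sketch (illustrative values): paint a peak onto an empty
+# CenterNet-style heatmap; the center is given as (x, y).
+#
+#   heatmap = torch.zeros(128, 128)
+#   heatmap = gen_gaussian_target(heatmap, center=[64, 32], radius=12)
+#   assert heatmap[32, 64] == 1.0  # k=1 at the kernel center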
+
+def gaussian_radius(det_size, min_overlap):
+ r"""Generate 2D gaussian radius.
+
+ This function is modified from the `official CornerNet github repo
+ <https://github.com/princeton-vl/CornerNet-Lite>`_.
+
+ Given ``min_overlap``, the radius can be computed by solving a
+ quadratic equation according to Vieta's formulas.
+
+ There are 3 cases for computing the gaussian radius, detailed below:
+
+ - Explanation of figure: ``lt`` and ``br`` indicates the left-top and
+ bottom-right corner of ground truth box. ``x`` indicates the
+ generated corner at the limited position when ``radius=r``.
+
+ - Case1: one corner is inside the gt box and the other is outside.
+
+ .. code:: text
+
+ |< width >|
+
+ lt-+----------+ -
+ | | | ^
+ +--x----------+--+
+ | | | |
+ | | | | height
+ | | overlap | |
+ | | | |
+ | | | | v
+ +--+---------br--+ -
+ | | |
+ +----------+--x
+
+ To ensure IoU of generated box and gt box is larger than ``min_overlap``:
+
+ .. math::
+ \cfrac{(w-r)*(h-r)}{w*h+(w+h)r-r^2} \ge {iou} \quad\Rightarrow\quad
+ {r^2-(w+h)r+\cfrac{1-iou}{1+iou}*w*h} \ge 0 \\
+ {a} = 1,\quad{b} = {-(w+h)},\quad{c} = {\cfrac{1-iou}{1+iou}*w*h} \\
+ {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}
+
+ - Case2: both two corners are inside the gt box.
+
+ .. code:: text
+
+ |< width >|
+
+ lt-+----------+ -
+ | | | ^
+ +--x-------+ |
+ | | | |
+ | |overlap| | height
+ | | | |
+ | +-------x--+
+ | | | v
+ +----------+-br -
+
+ To ensure IoU of generated box and gt box is larger than ``min_overlap``:
+
+ .. math::
+ \cfrac{(w-2*r)*(h-2*r)}{w*h} \ge {iou} \quad\Rightarrow\quad
+ {4r^2-2(w+h)r+(1-iou)*w*h} \ge 0 \\
+ {a} = 4,\quad {b} = {-2(w+h)},\quad {c} = {(1-iou)*w*h} \\
+ {r} \le \cfrac{-b-\sqrt{b^2-4*a*c}}{2*a}
+
+ - Case3: both two corners are outside the gt box.
+
+ .. code:: text
+
+ |< width >|
+
+ x--+----------------+
+ | | |
+ +-lt-------------+ | -
+ | | | | ^
+ | | | |
+ | | overlap | | height
+ | | | |
+ | | | | v
+ | +------------br--+ -
+ | | |
+ +----------------+--x
+
+ To ensure IoU of generated box and gt box is larger than ``min_overlap``:
+
+ .. math::
+ \cfrac{w*h}{(w+2*r)*(h+2*r)} \ge {iou} \quad\Rightarrow\quad
+ {4*iou*r^2+2*iou*(w+h)r+(iou-1)*w*h} \le 0 \\
+ {a} = {4*iou},\quad {b} = {2*iou*(w+h)},\quad {c} = {(iou-1)*w*h} \\
+ {r} \le \cfrac{-b+\sqrt{b^2-4*a*c}}{2*a}
+
+ Args:
+ det_size (list[int]): Shape of object as (height, width).
+ min_overlap (float): Min IoU with ground truth for boxes generated by
+ keypoints inside the gaussian kernel.
+
+ Returns:
+ radius (int): Radius of gaussian kernel.
+ """
+ height, width = det_size
+
+ a1 = 1
+ b1 = (height + width)
+ c1 = width * height * (1 - min_overlap) / (1 + min_overlap)
+ sq1 = sqrt(b1**2 - 4 * a1 * c1)
+ r1 = (b1 - sq1) / (2 * a1)
+
+ a2 = 4
+ b2 = 2 * (height + width)
+ c2 = (1 - min_overlap) * width * height
+ sq2 = sqrt(b2**2 - 4 * a2 * c2)
+ r2 = (b2 - sq2) / (2 * a2)
+
+ a3 = 4 * min_overlap
+ b3 = -2 * min_overlap * (height + width)
+ c3 = (min_overlap - 1) * width * height
+ sq3 = sqrt(b3**2 - 4 * a3 * c3)
+ r3 = (b3 + sq3) / (2 * a3)
+ return min(r1, r2, r3)
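+
+# A minimal usage sketch: the returned value is the most conservative radius
+# over the three cases above; callers typically clamp and truncate it.
+#
+#   r = gaussian_radius((40, 80), min_overlap=0.7)  # (height, width)
+#   radius = max(0, int(r))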
diff --git a/mmdet/models/utils/positional_encoding.py b/mmdet/models/utils/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bda2bbdbfcc28ba6304b6325ae556fa02554ac1
--- /dev/null
+++ b/mmdet/models/utils/positional_encoding.py
@@ -0,0 +1,150 @@
+import math
+
+import torch
+import torch.nn as nn
+from mmcv.cnn import uniform_init
+
+from .builder import POSITIONAL_ENCODING
+
+
+@POSITIONAL_ENCODING.register_module()
+class SinePositionalEncoding(nn.Module):
+ """Position encoding with sine and cosine functions.
+
+ See `End-to-End Object Detection with Transformers
+ <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. Note the final returned dimension
+ for each position is 2 times of this value.
+ temperature (int, optional): The temperature used for scaling
+ the position embedding. Default 10000.
+ normalize (bool, optional): Whether to normalize the position
+ embedding. Default False.
+ scale (float, optional): A scale factor that scales the position
+ embedding. The scale will be used only when `normalize` is True.
+ Default 2*pi.
+ eps (float, optional): A value added to the denominator for
+ numerical stability. Default 1e-6.
+ """
+
+ def __init__(self,
+ num_feats,
+ temperature=10000,
+ normalize=False,
+ scale=2 * math.pi,
+ eps=1e-6):
+ super(SinePositionalEncoding, self).__init__()
+ if normalize:
+ assert isinstance(scale, (float, int)), 'when normalize is set,' \
+ 'scale should be provided and in float or int type, ' \
+ f'found {type(scale)}'
+ self.num_feats = num_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ self.scale = scale
+ self.eps = eps
+
+ def forward(self, mask):
+ """Forward function for `SinePositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values represent
+ ignored positions, while zero values mean valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ y_embed = y_embed / (y_embed[:, -1:, :] + self.eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + self.eps) * self.scale
+ dim_t = torch.arange(
+ self.num_feats, dtype=torch.float32, device=mask.device)
+ dim_t = self.temperature**(2 * (dim_t // 2) / self.num_feats)
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()),
+ dim=4).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'temperature={self.temperature}, '
+ repr_str += f'normalize={self.normalize}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'eps={self.eps})'
+ return repr_str
+
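+# A usage sketch (illustrative shapes): DETR-style configs set ``num_feats``
+# to half of the transformer ``embed_dims``, so a [bs, h, w] padding mask
+# yields a [bs, 2*num_feats, h, w] embedding.
+#
+#   pos_enc = SinePositionalEncoding(num_feats=128, normalize=True)
+#   mask = torch.zeros(2, 32, 32, dtype=torch.bool)  # no padded pixels
+#   pos = pos_enc(mask)  # -> [2, 256, 32, 32]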
+
+@POSITIONAL_ENCODING.register_module()
+class LearnedPositionalEncoding(nn.Module):
+ """Position embedding with learnable embedding weights.
+
+ Args:
+ num_feats (int): The feature dimension for each position
+ along x-axis or y-axis. The final returned dimension for
+ each position is 2 times of this value.
+ row_num_embed (int, optional): The dictionary size of row embeddings.
+ Default 50.
+ col_num_embed (int, optional): The dictionary size of col embeddings.
+ Default 50.
+ """
+
+ def __init__(self, num_feats, row_num_embed=50, col_num_embed=50):
+ super(LearnedPositionalEncoding, self).__init__()
+ self.row_embed = nn.Embedding(row_num_embed, num_feats)
+ self.col_embed = nn.Embedding(col_num_embed, num_feats)
+ self.num_feats = num_feats
+ self.row_num_embed = row_num_embed
+ self.col_num_embed = col_num_embed
+ self.init_weights()
+
+ def init_weights(self):
+ """Initialize the learnable weights."""
+ uniform_init(self.row_embed)
+ uniform_init(self.col_embed)
+
+ def forward(self, mask):
+ """Forward function for `LearnedPositionalEncoding`.
+
+ Args:
+ mask (Tensor): ByteTensor mask. Non-zero values represent
+ ignored positions, while zero values mean valid positions
+ for this image. Shape [bs, h, w].
+
+ Returns:
+ pos (Tensor): Returned position embedding with shape
+ [bs, num_feats*2, h, w].
+ """
+ h, w = mask.shape[-2:]
+ x = torch.arange(w, device=mask.device)
+ y = torch.arange(h, device=mask.device)
+ x_embed = self.col_embed(x)
+ y_embed = self.row_embed(y)
+ pos = torch.cat(
+ (x_embed.unsqueeze(0).repeat(h, 1, 1), y_embed.unsqueeze(1).repeat(
+ 1, w, 1)),
+ dim=-1).permute(2, 0,
+ 1).unsqueeze(0).repeat(mask.shape[0], 1, 1, 1)
+ return pos
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_feats={self.num_feats}, '
+ repr_str += f'row_num_embed={self.row_num_embed}, '
+ repr_str += f'col_num_embed={self.col_num_embed})'
+ return repr_str
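+
+# A usage sketch: the learned tables cover at most row_num_embed x
+# col_num_embed positions, so inputs must satisfy h <= 50 and w <= 50 with
+# the defaults.
+#
+#   pos_enc = LearnedPositionalEncoding(num_feats=128)
+#   pos = pos_enc(torch.zeros(2, 32, 32, dtype=torch.bool))  # [2, 256, 32, 32]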
diff --git a/mmdet/models/utils/res_layer.py b/mmdet/models/utils/res_layer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a4efd3dd30b30123ed5135eac080ad9f7f7b448
--- /dev/null
+++ b/mmdet/models/utils/res_layer.py
@@ -0,0 +1,187 @@
+from mmcv.cnn import build_conv_layer, build_norm_layer
+from torch import nn as nn
+
+
+class ResLayer(nn.Sequential):
+ """ResLayer to build ResNet style backbone.
+
+ Args:
+ block (nn.Module): block used to build ResLayer.
+ inplanes (int): inplanes of block.
+ planes (int): planes of block.
+ num_blocks (int): number of blocks.
+ stride (int): stride of the first block. Default: 1
+ avg_down (bool): Use AvgPool instead of stride conv when
+ downsampling in the bottleneck. Default: False
+ conv_cfg (dict): dictionary to construct and config conv layer.
+ Default: None
+ norm_cfg (dict): dictionary to construct and config norm layer.
+ Default: dict(type='BN')
+ downsample_first (bool): Downsample at the first block or last block.
+ False for Hourglass, True for ResNet. Default: True
+ """
+
+ def __init__(self,
+ block,
+ inplanes,
+ planes,
+ num_blocks,
+ stride=1,
+ avg_down=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ downsample_first=True,
+ **kwargs):
+ self.block = block
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = []
+ conv_stride = stride
+ if avg_down:
+ conv_stride = 1
+ downsample.append(
+ nn.AvgPool2d(
+ kernel_size=stride,
+ stride=stride,
+ ceil_mode=True,
+ count_include_pad=False))
+ downsample.extend([
+ build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=conv_stride,
+ bias=False),
+ build_norm_layer(norm_cfg, planes * block.expansion)[1]
+ ])
+ downsample = nn.Sequential(*downsample)
+
+ layers = []
+ if downsample_first:
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ inplanes = planes * block.expansion
+ for _ in range(1, num_blocks):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+
+ else: # downsample_first=False is for HourglassModule
+ for _ in range(num_blocks - 1):
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=inplanes,
+ stride=1,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ layers.append(
+ block(
+ inplanes=inplanes,
+ planes=planes,
+ stride=stride,
+ downsample=downsample,
+ conv_cfg=conv_cfg,
+ norm_cfg=norm_cfg,
+ **kwargs))
+ super(ResLayer, self).__init__(*layers)
+
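+# A minimal sketch (assuming mmdet's ResNet ``Bottleneck`` is importable):
+# build a stride-2 stage of three bottleneck blocks.
+#
+#   from mmdet.models.backbones.resnet import Bottleneck
+#   layer = ResLayer(Bottleneck, inplanes=256, planes=128, num_blocks=3,
+#                    stride=2)  # output channels: 128 * Bottleneck.expansion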
+
+class SimplifiedBasicBlock(nn.Module):
+ """Simplified version of original basic residual block. This is used in
+ `SCNet <https://arxiv.org/abs/2012.10150>`_.
+
+ - Norm layer is now optional
+ - Last ReLU in forward function is removed
+ """
+ expansion = 1
+
+ def __init__(self,
+ inplanes,
+ planes,
+ stride=1,
+ dilation=1,
+ downsample=None,
+ style='pytorch',
+ with_cp=False,
+ conv_cfg=None,
+ norm_cfg=dict(type='BN'),
+ dcn=None,
+ plugins=None):
+ super(SimplifiedBasicBlock, self).__init__()
+ assert dcn is None, 'Not implemented yet.'
+ assert plugins is None, 'Not implemented yet.'
+ assert not with_cp, 'Not implemented yet.'
+ self.with_norm = norm_cfg is not None
+ with_bias = norm_cfg is None
+ self.conv1 = build_conv_layer(
+ conv_cfg,
+ inplanes,
+ planes,
+ 3,
+ stride=stride,
+ padding=dilation,
+ dilation=dilation,
+ bias=with_bias)
+ if self.with_norm:
+ self.norm1_name, norm1 = build_norm_layer(
+ norm_cfg, planes, postfix=1)
+ self.add_module(self.norm1_name, norm1)
+ self.conv2 = build_conv_layer(
+ conv_cfg, planes, planes, 3, padding=1, bias=with_bias)
+ if self.with_norm:
+ self.norm2_name, norm2 = build_norm_layer(
+ norm_cfg, planes, postfix=2)
+ self.add_module(self.norm2_name, norm2)
+
+ self.relu = nn.ReLU(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.with_cp = with_cp
+
+ @property
+ def norm1(self):
+ """nn.Module: normalization layer after the first convolution layer"""
+ return getattr(self, self.norm1_name) if self.with_norm else None
+
+ @property
+ def norm2(self):
+ """nn.Module: normalization layer after the second convolution layer"""
+ return getattr(self, self.norm2_name) if self.with_norm else None
+
+ def forward(self, x):
+ """Forward function."""
+
+ identity = x
+
+ out = self.conv1(x)
+ if self.with_norm:
+ out = self.norm1(out)
+ out = self.relu(out)
+
+ out = self.conv2(out)
+ if self.with_norm:
+ out = self.norm2(out)
+
+ if self.downsample is not None:
+ identity = self.downsample(x)
+
+ out += identity
+
+ return out
diff --git a/mmdet/models/utils/transformer.py b/mmdet/models/utils/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..83870eead42f4b0bf73c9e19248d5512d3d044c5
--- /dev/null
+++ b/mmdet/models/utils/transformer.py
@@ -0,0 +1,860 @@
+import torch
+import torch.nn as nn
+from mmcv.cnn import (Linear, build_activation_layer, build_norm_layer,
+ xavier_init)
+
+from .builder import TRANSFORMER
+
+
+class MultiheadAttention(nn.Module):
+ """A warpper for torch.nn.MultiheadAttention.
+
+ This module implements MultiheadAttention with residual connection,
+ and positional encoding used in DETR is also passed as input.
+
+ Args:
+ embed_dims (int): The embedding dimension.
+ num_heads (int): Parallel attention heads. Same as
+ `nn.MultiheadAttention`.
+ dropout (float): A Dropout layer on attn_output_weights. Default 0.0.
+ """
+
+ def __init__(self, embed_dims, num_heads, dropout=0.0):
+ super(MultiheadAttention, self).__init__()
+ assert embed_dims % num_heads == 0, 'embed_dims must be ' \
+ f'divisible by num_heads. got {embed_dims} and {num_heads}.'
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.attn = nn.MultiheadAttention(embed_dims, num_heads, dropout)
+ # keep the float in ``self.dropout`` for ``__repr__``; the module gets
+ # its own attribute so the two do not shadow each other
+ self.proj_drop = nn.Dropout(dropout)
+
+ def forward(self,
+ x,
+ key=None,
+ value=None,
+ residual=None,
+ query_pos=None,
+ key_pos=None,
+ attn_mask=None,
+ key_padding_mask=None):
+ """Forward function for `MultiheadAttention`.
+
+ Args:
+ x (Tensor): The input query with shape [num_query, bs,
+ embed_dims]. Same in `nn.MultiheadAttention.forward`.
+ key (Tensor): The key tensor with shape [num_key, bs,
+ embed_dims]. Same in `nn.MultiheadAttention.forward`.
+ Default None. If None, the `query` will be used.
+ value (Tensor): The value tensor with same shape as `key`.
+ Same in `nn.MultiheadAttention.forward`. Default None.
+ If None, the `key` will be used.
+ residual (Tensor): The tensor used for addition, with the
+ same shape as `x`. Default None. If None, `x` will be used.
+ query_pos (Tensor): The positional encoding for query, with
+ the same shape as `x`. Default None. If not None, it will
+ be added to `x` before forward function.
+ key_pos (Tensor): The positional encoding for `key`, with the
+ same shape as `key`. Default None. If not None, it will
+ be added to `key` before forward function. If None, and
+ `query_pos` has the same shape as `key`, then `query_pos`
+ will be used for `key_pos`.
+ attn_mask (Tensor): ByteTensor mask with shape [num_query,
+ num_key]. Same in `nn.MultiheadAttention.forward`.
+ Default None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
+ Same in `nn.MultiheadAttention.forward`. Default None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ query = x
+ if key is None:
+ key = query
+ if value is None:
+ value = key
+ if residual is None:
+ residual = x
+ if key_pos is None:
+ if query_pos is not None and key is not None:
+ if query_pos.shape == key.shape:
+ key_pos = query_pos
+ if query_pos is not None:
+ query = query + query_pos
+ if key_pos is not None:
+ key = key + key_pos
+ out = self.attn(
+ query,
+ key,
+ value=value,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask)[0]
+
+ return residual + self.proj_drop(out)
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(embed_dims={self.embed_dims}, '
+ repr_str += f'num_heads={self.num_heads}, '
+ repr_str += f'dropout={self.dropout})'
+ return repr_str
+
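+# A usage sketch: inputs follow nn.MultiheadAttention's [seq, batch, dim]
+# layout and the residual connection is applied inside the module.
+#
+#   attn = MultiheadAttention(embed_dims=256, num_heads=8)
+#   x = torch.rand(100, 2, 256)  # [num_query, bs, embed_dims]
+#   out = attn(x)                # self-attention; same shape as x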
+
+class FFN(nn.Module):
+ """Implements feed-forward networks (FFNs) with residual connection.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `MultiheadAttention`.
+ feedforward_channels (int): The hidden dimension of FFNs.
+ num_fcs (int, optional): The number of fully-connected layers in
+ FFNs. Defaults to 2.
+ act_cfg (dict, optional): The activation config for FFNs.
+ dropout (float, optional): Probability of an element to be
+ zeroed. Default 0.0.
+ add_residual (bool, optional): Add residual connection.
+ Defaults to True.
+ """
+
+ def __init__(self,
+ embed_dims,
+ feedforward_channels,
+ num_fcs=2,
+ act_cfg=dict(type='ReLU', inplace=True),
+ dropout=0.0,
+ add_residual=True):
+ super(FFN, self).__init__()
+ assert num_fcs >= 2, 'num_fcs should be no less ' \
+ f'than 2. got {num_fcs}.'
+ self.embed_dims = embed_dims
+ self.feedforward_channels = feedforward_channels
+ self.num_fcs = num_fcs
+ self.act_cfg = act_cfg
+ self.dropout = dropout
+ self.activate = build_activation_layer(act_cfg)
+
+ layers = nn.ModuleList()
+ in_channels = embed_dims
+ for _ in range(num_fcs - 1):
+ layers.append(
+ nn.Sequential(
+ Linear(in_channels, feedforward_channels), self.activate,
+ nn.Dropout(dropout)))
+ in_channels = feedforward_channels
+ layers.append(Linear(feedforward_channels, embed_dims))
+ self.layers = nn.Sequential(*layers)
+ # as in MultiheadAttention, keep the float in ``self.dropout``
+ self.proj_drop = nn.Dropout(dropout)
+ self.add_residual = add_residual
+
+ def forward(self, x, residual=None):
+ """Forward function for `FFN`."""
+ out = self.layers(x)
+ if not self.add_residual:
+ return out
+ if residual is None:
+ residual = x
+ return residual + self.proj_drop(out)
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(embed_dims={self.embed_dims}, '
+ repr_str += f'feedforward_channels={self.feedforward_channels}, '
+ repr_str += f'num_fcs={self.num_fcs}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'dropout={self.dropout}, '
+ repr_str += f'add_residual={self.add_residual})'
+ return repr_str
+
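+# A usage sketch: the FFN keeps the token layout unchanged and adds the
+# residual internally.
+#
+#   ffn = FFN(embed_dims=256, feedforward_channels=1024)
+#   y = ffn(torch.rand(100, 2, 256))  # -> [100, 2, 256]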
+
+class TransformerEncoderLayer(nn.Module):
+ """Implements one encoder layer in DETR transformer.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as `FFN`.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): The hidden dimension for FFNs.
+ dropout (float): Probability of an element to be zeroed. Default 0.0.
+ order (tuple[str]): The order for encoder layer. Valid examples are
+ ('selfattn', 'norm', 'ffn', 'norm') and ('norm', 'selfattn',
+ 'norm', 'ffn'). Default ('selfattn', 'norm', 'ffn', 'norm').
+ act_cfg (dict): The activation config for FFNs. Default ReLU.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ num_fcs (int): The number of fully-connected layers for FFNs.
+ Default 2.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ dropout=0.0,
+ order=('selfattn', 'norm', 'ffn', 'norm'),
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ num_fcs=2):
+ super(TransformerEncoderLayer, self).__init__()
+ assert isinstance(order, tuple) and len(order) == 4
+ assert set(order) == set(['selfattn', 'norm', 'ffn'])
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.feedforward_channels = feedforward_channels
+ self.dropout = dropout
+ self.order = order
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.num_fcs = num_fcs
+ self.pre_norm = order[0] == 'norm'
+ self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
+ self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
+ dropout)
+ self.norms = nn.ModuleList()
+ self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])
+ self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])
+
+ def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
+ """Forward function for `TransformerEncoderLayer`.
+
+ Args:
+ x (Tensor): The input query with shape [num_key, bs,
+ embed_dims]. Same in `MultiheadAttention.forward`.
+ pos (Tensor): The positional encoding for query. Default None.
+ Same as `query_pos` in `MultiheadAttention.forward`.
+ attn_mask (Tensor): ByteTensor mask with shape [num_key,
+ num_key]. Same in `MultiheadAttention.forward`. Default None.
+ key_padding_mask (Tensor): ByteTensor with shape [bs, num_key].
+ Same in `MultiheadAttention.forward`. Default None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_key, bs, embed_dims].
+ """
+ norm_cnt = 0
+ inp_residual = x
+ for layer in self.order:
+ if layer == 'selfattn':
+ # self attention
+ query = key = value = x
+ x = self.self_attn(
+ query,
+ key,
+ value,
+ inp_residual if self.pre_norm else None,
+ query_pos=pos,
+ key_pos=pos,
+ attn_mask=attn_mask,
+ key_padding_mask=key_padding_mask)
+ inp_residual = x
+ elif layer == 'norm':
+ x = self.norms[norm_cnt](x)
+ norm_cnt += 1
+ elif layer == 'ffn':
+ x = self.ffn(x, inp_residual if self.pre_norm else None)
+ return x
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(embed_dims={self.embed_dims}, '
+ repr_str += f'num_heads={self.num_heads}, '
+ repr_str += f'feedforward_channels={self.feedforward_channels}, '
+ repr_str += f'dropout={self.dropout}, '
+ repr_str += f'order={self.order}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'norm_cfg={self.norm_cfg}, '
+ repr_str += f'num_fcs={self.num_fcs})'
+ return repr_str
+
+
+class TransformerDecoderLayer(nn.Module):
+ """Implements one decoder layer in DETR transformer.
+
+ Args:
+ embed_dims (int): The feature dimension. Same as
+ `TransformerEncoderLayer`.
+ num_heads (int): Parallel attention heads.
+ feedforward_channels (int): Same as `TransformerEncoderLayer`.
+ dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
+ order (tuple[str]): The order for decoder layer. Valid examples are
+ ('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn', 'norm') and
+ ('norm', 'selfattn', 'norm', 'multiheadattn', 'norm', 'ffn').
+ Default the former.
+ act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ num_fcs (int): The number of fully-connected layers in FFNs.
+ """
+
+ def __init__(self,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ dropout=0.0,
+ order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
+ 'norm'),
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ num_fcs=2):
+ super(TransformerDecoderLayer, self).__init__()
+ assert isinstance(order, tuple) and len(order) == 6
+ assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.feedforward_channels = feedforward_channels
+ self.dropout = dropout
+ self.order = order
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.num_fcs = num_fcs
+ self.pre_norm = order[0] == 'norm'
+ self.self_attn = MultiheadAttention(embed_dims, num_heads, dropout)
+ self.multihead_attn = MultiheadAttention(embed_dims, num_heads,
+ dropout)
+ self.ffn = FFN(embed_dims, feedforward_channels, num_fcs, act_cfg,
+ dropout)
+ self.norms = nn.ModuleList()
+ # 3 norm layers in official DETR's TransformerDecoderLayer
+ for _ in range(3):
+ self.norms.append(build_norm_layer(norm_cfg, embed_dims)[1])
+
+ def forward(self,
+ x,
+ memory,
+ memory_pos=None,
+ query_pos=None,
+ memory_attn_mask=None,
+ target_attn_mask=None,
+ memory_key_padding_mask=None,
+ target_key_padding_mask=None):
+ """Forward function for `TransformerDecoderLayer`.
+
+ Args:
+ x (Tensor): Input query with shape [num_query, bs, embed_dims].
+ memory (Tensor): Tensor got from `TransformerEncoder`, with shape
+ [num_key, bs, embed_dims].
+ memory_pos (Tensor): The positional encoding for `memory`. Default
+ None. Same as `key_pos` in `MultiheadAttention.forward`.
+ query_pos (Tensor): The positional encoding for `query`. Default
+ None. Same as `query_pos` in `MultiheadAttention.forward`.
+ memory_attn_mask (Tensor): ByteTensor mask for `memory`, with
+ shape [num_key, num_key]. Same as `attn_mask` in
+ `MultiheadAttention.forward`. Default None.
+ target_attn_mask (Tensor): ByteTensor mask for `x`, with shape
+ [num_query, num_query]. Same as `attn_mask` in
+ `MultiheadAttention.forward`. Default None.
+ memory_key_padding_mask (Tensor): ByteTensor for `memory`, with
+ shape [bs, num_key]. Same as `key_padding_mask` in
+ `MultiheadAttention.forward`. Default None.
+ target_key_padding_mask (Tensor): ByteTensor for `x`, with shape
+ [bs, num_query]. Same as `key_padding_mask` in
+ `MultiheadAttention.forward`. Default None.
+
+ Returns:
+ Tensor: forwarded results with shape [num_query, bs, embed_dims].
+ """
+ norm_cnt = 0
+ inp_residual = x
+ for layer in self.order:
+ if layer == 'selfattn':
+ query = key = value = x
+ x = self.self_attn(
+ query,
+ key,
+ value,
+ inp_residual if self.pre_norm else None,
+ query_pos,
+ key_pos=query_pos,
+ attn_mask=target_attn_mask,
+ key_padding_mask=target_key_padding_mask)
+ inp_residual = x
+ elif layer == 'norm':
+ x = self.norms[norm_cnt](x)
+ norm_cnt += 1
+ elif layer == 'multiheadattn':
+ query = x
+ key = value = memory
+ x = self.multihead_attn(
+ query,
+ key,
+ value,
+ inp_residual if self.pre_norm else None,
+ query_pos,
+ key_pos=memory_pos,
+ attn_mask=memory_attn_mask,
+ key_padding_mask=memory_key_padding_mask)
+ inp_residual = x
+ elif layer == 'ffn':
+ x = self.ffn(x, inp_residual if self.pre_norm else None)
+ return x
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(embed_dims={self.embed_dims}, '
+ repr_str += f'num_heads={self.num_heads}, '
+ repr_str += f'feedforward_channels={self.feedforward_channels}, '
+ repr_str += f'dropout={self.dropout}, '
+ repr_str += f'order={self.order}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'norm_cfg={self.norm_cfg}, '
+ repr_str += f'num_fcs={self.num_fcs})'
+ return repr_str
+
+
+class TransformerEncoder(nn.Module):
+ """Implements the encoder in DETR transformer.
+
+ Args:
+ num_layers (int): The number of `TransformerEncoderLayer`.
+ embed_dims (int): Same as `TransformerEncoderLayer`.
+ num_heads (int): Same as `TransformerEncoderLayer`.
+ feedforward_channels (int): Same as `TransformerEncoderLayer`.
+ dropout (float): Same as `TransformerEncoderLayer`. Default 0.0.
+ order (tuple[str]): Same as `TransformerEncoderLayer`.
+ act_cfg (dict): Same as `TransformerEncoderLayer`. Default ReLU.
+ norm_cfg (dict): Same as `TransformerEncoderLayer`. Default
+ layer normalization.
+ num_fcs (int): Same as `TransformerEncoderLayer`. Default 2.
+ """
+
+ def __init__(self,
+ num_layers,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ dropout=0.0,
+ order=('selfattn', 'norm', 'ffn', 'norm'),
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ num_fcs=2):
+ super(TransformerEncoder, self).__init__()
+ assert isinstance(order, tuple) and len(order) == 4
+ assert set(order) == set(['selfattn', 'norm', 'ffn'])
+ self.num_layers = num_layers
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.feedforward_channels = feedforward_channels
+ self.dropout = dropout
+ self.order = order
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.num_fcs = num_fcs
+ self.pre_norm = order[0] == 'norm'
+ self.layers = nn.ModuleList()
+ for _ in range(num_layers):
+ self.layers.append(
+ TransformerEncoderLayer(embed_dims, num_heads,
+ feedforward_channels, dropout, order,
+ act_cfg, norm_cfg, num_fcs))
+ self.norm = build_norm_layer(norm_cfg,
+ embed_dims)[1] if self.pre_norm else None
+
+ def forward(self, x, pos=None, attn_mask=None, key_padding_mask=None):
+ """Forward function for `TransformerEncoder`.
+
+ Args:
+ x (Tensor): Input query. Same in `TransformerEncoderLayer.forward`.
+ pos (Tensor): Positional encoding for query. Default None.
+ Same in `TransformerEncoderLayer.forward`.
+ attn_mask (Tensor): ByteTensor attention mask. Default None.
+ Same in `TransformerEncoderLayer.forward`.
+ key_padding_mask (Tensor): Same in
+ `TransformerEncoderLayer.forward`. Default None.
+
+ Returns:
+ Tensor: Results with shape [num_key, bs, embed_dims].
+ """
+ for layer in self.layers:
+ x = layer(x, pos, attn_mask, key_padding_mask)
+ if self.norm is not None:
+ x = self.norm(x)
+ return x
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_layers={self.num_layers}, '
+ repr_str += f'embed_dims={self.embed_dims}, '
+ repr_str += f'num_heads={self.num_heads}, '
+ repr_str += f'feedforward_channels={self.feedforward_channels}, '
+ repr_str += f'dropout={self.dropout}, '
+ repr_str += f'order={self.order}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'norm_cfg={self.norm_cfg}, '
+ repr_str += f'num_fcs={self.num_fcs})'
+ return repr_str
+
+
+class TransformerDecoder(nn.Module):
+ """Implements the decoder in DETR transformer.
+
+ Args:
+ num_layers (int): The number of `TransformerDecoderLayer`.
+ embed_dims (int): Same as `TransformerDecoderLayer`.
+ num_heads (int): Same as `TransformerDecoderLayer`.
+ feedforward_channels (int): Same as `TransformerDecoderLayer`.
+ dropout (float): Same as `TransformerDecoderLayer`. Default 0.0.
+ order (tuple[str]): Same as `TransformerDecoderLayer`.
+ act_cfg (dict): Same as `TransformerDecoderLayer`. Default ReLU.
+ norm_cfg (dict): Same as `TransformerDecoderLayer`. Default
+ layer normalization.
+ num_fcs (int): Same as `TransformerDecoderLayer`. Default 2.
+ """
+
+ def __init__(self,
+ num_layers,
+ embed_dims,
+ num_heads,
+ feedforward_channels,
+ dropout=0.0,
+ order=('selfattn', 'norm', 'multiheadattn', 'norm', 'ffn',
+ 'norm'),
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ num_fcs=2,
+ return_intermediate=False):
+ super(TransformerDecoder, self).__init__()
+ assert isinstance(order, tuple) and len(order) == 6
+ assert set(order) == set(['selfattn', 'norm', 'multiheadattn', 'ffn'])
+ self.num_layers = num_layers
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.feedforward_channels = feedforward_channels
+ self.dropout = dropout
+ self.order = order
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.num_fcs = num_fcs
+ self.return_intermediate = return_intermediate
+ self.layers = nn.ModuleList()
+ for _ in range(num_layers):
+ self.layers.append(
+ TransformerDecoderLayer(embed_dims, num_heads,
+ feedforward_channels, dropout, order,
+ act_cfg, norm_cfg, num_fcs))
+ self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
+
+ def forward(self,
+ x,
+ memory,
+ memory_pos=None,
+ query_pos=None,
+ memory_attn_mask=None,
+ target_attn_mask=None,
+ memory_key_padding_mask=None,
+ target_key_padding_mask=None):
+ """Forward function for `TransformerDecoder`.
+
+ Args:
+ x (Tensor): Input query. Same in `TransformerDecoderLayer.forward`.
+ memory (Tensor): Same in `TransformerDecoderLayer.forward`.
+ memory_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
+ Default None.
+ query_pos (Tensor): Same in `TransformerDecoderLayer.forward`.
+ Default None.
+ memory_attn_mask (Tensor): Same in
+ `TransformerDecoderLayer.forward`. Default None.
+ target_attn_mask (Tensor): Same in
+ `TransformerDecoderLayer.forward`. Default None.
+ memory_key_padding_mask (Tensor): Same in
+ `TransformerDecoderLayer.forward`. Default None.
+ target_key_padding_mask (Tensor): Same in
+ `TransformerDecoderLayer.forward`. Default None.
+
+ Returns:
+ Tensor: Results with shape [num_query, bs, embed_dims].
+ """
+ intermediate = []
+ for layer in self.layers:
+ x = layer(x, memory, memory_pos, query_pos, memory_attn_mask,
+ target_attn_mask, memory_key_padding_mask,
+ target_key_padding_mask)
+ if self.return_intermediate:
+ intermediate.append(self.norm(x))
+ if self.norm is not None:
+ x = self.norm(x)
+ if self.return_intermediate:
+ intermediate.pop()
+ intermediate.append(x)
+ if self.return_intermediate:
+ return torch.stack(intermediate)
+ return x.unsqueeze(0)
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(num_layers={self.num_layers}, '
+ repr_str += f'embed_dims={self.embed_dims}, '
+ repr_str += f'num_heads={self.num_heads}, '
+ repr_str += f'feedforward_channels={self.feedforward_channels}, '
+ repr_str += f'dropout={self.dropout}, '
+ repr_str += f'order={self.order}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'norm_cfg={self.norm_cfg}, '
+ repr_str += f'num_fcs={self.num_fcs}, '
+ repr_str += f'return_intermediate={self.return_intermediate})'
+ return repr_str
+
+
+@TRANSFORMER.register_module()
+class Transformer(nn.Module):
+ """Implements the DETR transformer.
+
+ Following the official DETR implementation, this module is copy-pasted
+ from torch.nn.Transformer, with the following modifications:
+
+ * positional encodings are passed in MultiheadAttention
+ * extra LN at the end of encoder is removed
+ * decoder returns a stack of activations from all decoding layers
+
+ See `paper: End-to-End Object Detection with Transformers
+ <https://arxiv.org/abs/2005.12872>`_ for details.
+
+ Args:
+ embed_dims (int): The feature dimension.
+ num_heads (int): Parallel attention heads. Same as
+ `nn.MultiheadAttention`.
+ num_encoder_layers (int): Number of `TransformerEncoderLayer`.
+ num_decoder_layers (int): Number of `TransformerDecoderLayer`.
+ feedforward_channels (int): The hidden dimension for FFNs used in both
+ encoder and decoder.
+ dropout (float): Probability of an element to be zeroed. Default 0.0.
+ act_cfg (dict): Activation config for FFNs used in both encoder
+ and decoder. Default ReLU.
+ norm_cfg (dict): Config dict for normalization used in both encoder
+ and decoder. Default layer normalization.
+ num_fcs (int): The number of fully-connected layers in FFNs, which is
+ used for both encoder and decoder.
+ pre_norm (bool): Whether the normalization layer is ordered
+ first in the encoder and decoder. Default False.
+ return_intermediate_dec (bool): Whether to return the intermediate
+ output from each TransformerDecoderLayer or only the last
+ TransformerDecoderLayer. Default False. If True, the returned
+ `hs` has shape [num_decoder_layers, bs, num_query, embed_dims];
+ if False, the returned `hs` has shape [1, bs, num_query,
+ embed_dims].
+ """
+
+ def __init__(self,
+ embed_dims=512,
+ num_heads=8,
+ num_encoder_layers=6,
+ num_decoder_layers=6,
+ feedforward_channels=2048,
+ dropout=0.0,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN'),
+ num_fcs=2,
+ pre_norm=False,
+ return_intermediate_dec=False):
+ super(Transformer, self).__init__()
+ self.embed_dims = embed_dims
+ self.num_heads = num_heads
+ self.num_encoder_layers = num_encoder_layers
+ self.num_decoder_layers = num_decoder_layers
+ self.feedforward_channels = feedforward_channels
+ self.dropout = dropout
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.num_fcs = num_fcs
+ self.pre_norm = pre_norm
+ self.return_intermediate_dec = return_intermediate_dec
+ if self.pre_norm:
+ encoder_order = ('norm', 'selfattn', 'norm', 'ffn')
+ decoder_order = ('norm', 'selfattn', 'norm', 'multiheadattn',
+ 'norm', 'ffn')
+ else:
+ encoder_order = ('selfattn', 'norm', 'ffn', 'norm')
+ decoder_order = ('selfattn', 'norm', 'multiheadattn', 'norm',
+ 'ffn', 'norm')
+ self.encoder = TransformerEncoder(num_encoder_layers, embed_dims,
+ num_heads, feedforward_channels,
+ dropout, encoder_order, act_cfg,
+ norm_cfg, num_fcs)
+ self.decoder = TransformerDecoder(num_decoder_layers, embed_dims,
+ num_heads, feedforward_channels,
+ dropout, decoder_order, act_cfg,
+ norm_cfg, num_fcs,
+ return_intermediate_dec)
+
+ def init_weights(self, distribution='uniform'):
+ """Initialize the transformer weights."""
+ # follow the official DETR to init parameters
+ for m in self.modules():
+ if hasattr(m, 'weight') and m.weight.dim() > 1:
+ xavier_init(m, distribution=distribution)
+
+ def forward(self, x, mask, query_embed, pos_embed):
+ """Forward function for `Transformer`.
+
+ Args:
+ x (Tensor): Input query with shape [bs, c, h, w] where
+ c = embed_dims.
+ mask (Tensor): The key_padding_mask used for encoder and decoder,
+ with shape [bs, h, w].
+ query_embed (Tensor): The query embedding for decoder, with shape
+ [num_query, c].
+ pos_embed (Tensor): The positional encoding for encoder and
+ decoder, with the same shape as `x`.
+
+ Returns:
+ tuple[Tensor]: results of decoder containing the following tensor.
+
+ - out_dec: Output from decoder. If return_intermediate_dec \
+ is True output has shape [num_dec_layers, bs,
+ num_query, embed_dims], else has shape [1, bs, \
+ num_query, embed_dims].
+ - memory: Output results from encoder, with shape \
+ [bs, embed_dims, h, w].
+ """
+ bs, c, h, w = x.shape
+ x = x.flatten(2).permute(2, 0, 1) # [bs, c, h, w] -> [h*w, bs, c]
+ pos_embed = pos_embed.flatten(2).permute(2, 0, 1)
+ query_embed = query_embed.unsqueeze(1).repeat(
+ 1, bs, 1) # [num_query, dim] -> [num_query, bs, dim]
+ mask = mask.flatten(1) # [bs, h, w] -> [bs, h*w]
+ memory = self.encoder(
+ x, pos=pos_embed, attn_mask=None, key_padding_mask=mask)
+ target = torch.zeros_like(query_embed)
+ # out_dec: [num_layers, num_query, bs, dim]
+ out_dec = self.decoder(
+ target,
+ memory,
+ memory_pos=pos_embed,
+ query_pos=query_embed,
+ memory_attn_mask=None,
+ target_attn_mask=None,
+ memory_key_padding_mask=mask,
+ target_key_padding_mask=None)
+ out_dec = out_dec.transpose(1, 2)
+ memory = memory.permute(1, 2, 0).reshape(bs, c, h, w)
+ return out_dec, memory
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(embed_dims={self.embed_dims}, '
+ repr_str += f'num_heads={self.num_heads}, '
+ repr_str += f'num_encoder_layers={self.num_encoder_layers}, '
+ repr_str += f'num_decoder_layers={self.num_decoder_layers}, '
+ repr_str += f'feedforward_channels={self.feedforward_channels}, '
+ repr_str += f'dropout={self.dropout}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'norm_cfg={self.norm_cfg}, '
+ repr_str += f'num_fcs={self.num_fcs}, '
+ repr_str += f'pre_norm={self.pre_norm}, '
+ repr_str += f'return_intermediate_dec={self.return_intermediate_dec})'
+ return repr_str
+
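+# A usage sketch of the DETR wiring (illustrative shapes; assumes
+# ``SinePositionalEncoding`` from positional_encoding.py and embed_dims=256):
+#
+#   transformer = Transformer(embed_dims=256, num_heads=8)
+#   transformer.init_weights()
+#   x = torch.rand(2, 256, 25, 34)                   # projected features
+#   mask = torch.zeros(2, 25, 34, dtype=torch.bool)  # no padded pixels
+#   query_embed = torch.rand(100, 256)               # nn.Embedding weight
+#   pos_embed = SinePositionalEncoding(128, normalize=True)(mask)
+#   out_dec, memory = transformer(x, mask, query_embed, pos_embed)
+#   # out_dec: [1, 2, 100, 256]; memory: [2, 256, 25, 34]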
+
+@TRANSFORMER.register_module()
+class DynamicConv(nn.Module):
+ """Implements Dynamic Convolution.
+
+ This module generates parameters for each sample and
+ uses bmm to implement 1x1 convolution. Code is modified
+ from the `official github repo <https://github.com/PeizeSun/SparseR-CNN>`_ .
+
+ Args:
+ in_channels (int): The input feature channel.
+ Defaults to 256.
+ feat_channels (int): The inner feature channel.
+ Defaults to 64.
+ out_channels (int, optional): The output feature channel.
+ When not specified, it will be set to `in_channels`
+ by default
+ input_feat_shape (int): The shape of input feature.
+ Defaults to 7.
+ act_cfg (dict): The activation config for DynamicConv.
+ norm_cfg (dict): Config dict for normalization layer. Default
+ layer normalization.
+ """
+
+ def __init__(self,
+ in_channels=256,
+ feat_channels=64,
+ out_channels=None,
+ input_feat_shape=7,
+ act_cfg=dict(type='ReLU', inplace=True),
+ norm_cfg=dict(type='LN')):
+ super(DynamicConv, self).__init__()
+ self.in_channels = in_channels
+ self.feat_channels = feat_channels
+ self.out_channels_raw = out_channels
+ self.input_feat_shape = input_feat_shape
+ self.act_cfg = act_cfg
+ self.norm_cfg = norm_cfg
+ self.out_channels = out_channels if out_channels else in_channels
+
+ self.num_params_in = self.in_channels * self.feat_channels
+ self.num_params_out = self.out_channels * self.feat_channels
+ self.dynamic_layer = nn.Linear(
+ self.in_channels, self.num_params_in + self.num_params_out)
+
+ self.norm_in = build_norm_layer(norm_cfg, self.feat_channels)[1]
+ self.norm_out = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ self.activation = build_activation_layer(act_cfg)
+
+ num_output = self.out_channels * input_feat_shape**2
+ self.fc_layer = nn.Linear(num_output, self.out_channels)
+ self.fc_norm = build_norm_layer(norm_cfg, self.out_channels)[1]
+
+ def forward(self, param_feature, input_feature):
+ """Forward function for `DynamicConv`.
+
+ Args:
+ param_feature (Tensor): The feature can be used
+ to generate the parameter, has shape
+ (num_all_proposals, in_channels).
+ input_feature (Tensor): Feature that
+ interact with parameters, has shape
+ (num_all_proposals, in_channels, H, W).
+
+ Returns:
+ Tensor: The output feature has shape
+ (num_all_proposals, out_channels).
+ """
+ num_proposals = param_feature.size(0)
+ # flatten the spatial dims: (N, C, H, W) -> (N, H*W, C)
+ input_feature = input_feature.view(num_proposals, self.in_channels,
+ -1).permute(0, 2, 1)
+ parameters = self.dynamic_layer(param_feature)
+
+ param_in = parameters[:, :self.num_params_in].view(
+ -1, self.in_channels, self.feat_channels)
+ param_out = parameters[:, -self.num_params_out:].view(
+ -1, self.feat_channels, self.out_channels)
+
+ # input_feature has shape (num_all_proposals, H*W, in_channels)
+ # param_in has shape (num_all_proposals, in_channels, feat_channels)
+ # feature has shape (num_all_proposals, H*W, feat_channels)
+ features = torch.bmm(input_feature, param_in)
+ features = self.norm_in(features)
+ features = self.activation(features)
+
+ # param_out has shape (num_all_proposals, feat_channels, out_channels)
+ features = torch.bmm(features, param_out)
+ features = self.norm_out(features)
+ features = self.activation(features)
+
+ features = features.flatten(1)
+ features = self.fc_layer(features)
+ features = self.fc_norm(features)
+ features = self.activation(features)
+
+ return features
+
+ def __repr__(self):
+ """str: a string that describes the module"""
+ repr_str = self.__class__.__name__
+ repr_str += f'(in_channels={self.in_channels}, '
+ repr_str += f'feat_channels={self.feat_channels}, '
+ repr_str += f'out_channels={self.out_channels_raw}, '
+ repr_str += f'input_feat_shape={self.input_feat_shape}, '
+ repr_str += f'act_cfg={self.act_cfg}, '
+ repr_str += f'norm_cfg={self.norm_cfg})'
+ return repr_str
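+
+# A usage sketch (Sparse R-CNN-style interaction between one parameter vector
+# per proposal and its 7x7 RoI feature):
+#
+#   dyn = DynamicConv(in_channels=256, feat_channels=64, input_feat_shape=7)
+#   param_feat = torch.rand(300, 256)
+#   roi_feat = torch.rand(300, 256, 7, 7)
+#   out = dyn(param_feat, roi_feat)  # -> [300, 256]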
diff --git a/mmdet/utils/__init__.py b/mmdet/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e79ad8c02a2d465f0690a4aa80683a5c6d784d52
--- /dev/null
+++ b/mmdet/utils/__init__.py
@@ -0,0 +1,5 @@
+from .collect_env import collect_env
+from .logger import get_root_logger
+from .optimizer import DistOptimizerHook
+
+__all__ = ['get_root_logger', 'collect_env', 'DistOptimizerHook']
diff --git a/mmdet/utils/collect_env.py b/mmdet/utils/collect_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..89c064accdb10abec4a03de04f601d27aab2da70
--- /dev/null
+++ b/mmdet/utils/collect_env.py
@@ -0,0 +1,16 @@
+from mmcv.utils import collect_env as collect_base_env
+from mmcv.utils import get_git_hash
+
+import mmdet
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ env_info = collect_base_env()
+ env_info['MMDetection'] = mmdet.__version__ + '+' + get_git_hash()[:7]
+ return env_info
+
+
+if __name__ == '__main__':
+ for name, val in collect_env().items():
+ print(f'{name}: {val}')
diff --git a/mmdet/utils/contextmanagers.py b/mmdet/utils/contextmanagers.py
new file mode 100644
index 0000000000000000000000000000000000000000..38a639262d949b5754dedf12f33fa814b030ea38
--- /dev/null
+++ b/mmdet/utils/contextmanagers.py
@@ -0,0 +1,121 @@
+import asyncio
+import contextlib
+import logging
+import os
+import time
+from typing import List, Optional
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+DEBUG_COMPLETED_TIME = bool(os.environ.get('DEBUG_COMPLETED_TIME', False))
+
+
+@contextlib.asynccontextmanager
+async def completed(trace_name='',
+ name='',
+ sleep_interval=0.05,
+ streams: Optional[List[torch.cuda.Stream]] = None):
+ """Async context manager that waits for work to complete on given CUDA
+ streams."""
+ if not torch.cuda.is_available():
+ yield
+ return
+
+ stream_before_context_switch = torch.cuda.current_stream()
+ if not streams:
+ streams = [stream_before_context_switch]
+ else:
+ streams = [s if s else stream_before_context_switch for s in streams]
+
+ end_events = [
+ torch.cuda.Event(enable_timing=DEBUG_COMPLETED_TIME) for _ in streams
+ ]
+
+ if DEBUG_COMPLETED_TIME:
+ start = torch.cuda.Event(enable_timing=True)
+ stream_before_context_switch.record_event(start)
+
+ cpu_start = time.monotonic()
+ logger.debug('%s %s starting, streams: %s', trace_name, name, streams)
+ grad_enabled_before = torch.is_grad_enabled()
+ try:
+ yield
+ finally:
+ current_stream = torch.cuda.current_stream()
+ assert current_stream == stream_before_context_switch
+
+ if DEBUG_COMPLETED_TIME:
+ cpu_end = time.monotonic()
+ for i, stream in enumerate(streams):
+ event = end_events[i]
+ stream.record_event(event)
+
+ grad_enabled_after = torch.is_grad_enabled()
+
+ # observed change of torch.is_grad_enabled() during concurrent run of
+ # async_test_bboxes code
+ assert (grad_enabled_before == grad_enabled_after
+ ), 'Unexpected is_grad_enabled() value change'
+
+ are_done = [e.query() for e in end_events]
+ logger.debug('%s %s completed: %s streams: %s', trace_name, name,
+ are_done, streams)
+ with torch.cuda.stream(stream_before_context_switch):
+ while not all(are_done):
+ await asyncio.sleep(sleep_interval)
+ are_done = [e.query() for e in end_events]
+ logger.debug(
+ '%s %s completed: %s streams: %s',
+ trace_name,
+ name,
+ are_done,
+ streams,
+ )
+
+ current_stream = torch.cuda.current_stream()
+ assert current_stream == stream_before_context_switch
+
+ if DEBUG_COMPLETED_TIME:
+ cpu_time = (cpu_end - cpu_start) * 1000
+ stream_times_ms = ''
+ for i, stream in enumerate(streams):
+ elapsed_time = start.elapsed_time(end_events[i])
+ stream_times_ms += f' {stream} {elapsed_time:.2f} ms'
+ logger.info('%s %s %.2f ms %s', trace_name, name, cpu_time,
+ stream_times_ms)
+
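+# A usage sketch (hypothetical call site inside an async test path):
+#
+#   async with completed('bbox_head', 'forward'):
+#       cls_score, bbox_pred = self.forward(x)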
+
+@contextlib.asynccontextmanager
+async def concurrent(streamqueue: asyncio.Queue,
+ trace_name='concurrent',
+ name='stream'):
+ """Run code concurrently in different streams.
+
+ :param streamqueue: asyncio.Queue instance.
+
+ Queue tasks define the pool of streams used for concurrent execution.
+ """
+ if not torch.cuda.is_available():
+ yield
+ return
+
+ initial_stream = torch.cuda.current_stream()
+
+ with torch.cuda.stream(initial_stream):
+ stream = await streamqueue.get()
+ assert isinstance(stream, torch.cuda.Stream)
+
+ try:
+ with torch.cuda.stream(stream):
+ logger.debug('%s %s is starting, stream: %s', trace_name, name,
+ stream)
+ yield
+ current = torch.cuda.current_stream()
+ assert current == stream
+ logger.debug('%s %s has finished, stream: %s', trace_name,
+ name, stream)
+ finally:
+ streamqueue.task_done()
+ streamqueue.put_nowait(stream)
diff --git a/mmdet/utils/logger.py b/mmdet/utils/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fc6e6b438a73e857ba6f173594985807cb88b30
--- /dev/null
+++ b/mmdet/utils/logger.py
@@ -0,0 +1,19 @@
+import logging
+
+from mmcv.utils import get_logger
+
+
+def get_root_logger(log_file=None, log_level=logging.INFO):
+ """Get root logger.
+
+ Args:
+ log_file (str, optional): File path of log. Defaults to None.
+ log_level (int, optional): The level of logger.
+ Defaults to logging.INFO.
+
+ Returns:
+ :obj:`logging.Logger`: The obtained logger
+ """
+ logger = get_logger(name='mmdet', log_file=log_file, log_level=log_level)
+
+ return logger
diff --git a/mmdet/utils/optimizer.py b/mmdet/utils/optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c9d11941c0b43d42bd6daad1e4b927eaca3e675
--- /dev/null
+++ b/mmdet/utils/optimizer.py
@@ -0,0 +1,33 @@
+from mmcv.runner import OptimizerHook, HOOKS
+try:
+ import apex
+except ImportError:
+ apex = None
+ print('apex is not installed')
+
+
+@HOOKS.register_module()
+class DistOptimizerHook(OptimizerHook):
+ """Optimizer hook for distributed training."""
+
+ def __init__(self, update_interval=1, grad_clip=None, coalesce=True, bucket_size_mb=-1, use_fp16=False):
+ self.grad_clip = grad_clip
+ self.coalesce = coalesce
+ self.bucket_size_mb = bucket_size_mb
+ self.update_interval = update_interval
+ self.use_fp16 = use_fp16
+
+ def before_run(self, runner):
+ runner.optimizer.zero_grad()
+
+ def after_train_iter(self, runner):
+ runner.outputs['loss'] /= self.update_interval
+ if self.use_fp16:
+ with apex.amp.scale_loss(runner.outputs['loss'], runner.optimizer) as scaled_loss:
+ scaled_loss.backward()
+ else:
+ runner.outputs['loss'].backward()
+ if self.every_n_iters(runner, self.update_interval):
+ if self.grad_clip is not None:
+ self.clip_grads(runner.model.parameters())
+ runner.optimizer.step()
+ runner.optimizer.zero_grad()
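+
+# A config sketch (assuming the standard mmcv hook-config convention):
+# accumulate gradients over 2 iterations and clip them before each step.
+#
+#   optimizer_config = dict(
+#       type='DistOptimizerHook',
+#       update_interval=2,
+#       grad_clip=dict(max_norm=35, norm_type=2))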
diff --git a/mmdet/utils/profiling.py b/mmdet/utils/profiling.py
new file mode 100644
index 0000000000000000000000000000000000000000..4be9222c37e922329d537f883f5587995e27efc6
--- /dev/null
+++ b/mmdet/utils/profiling.py
@@ -0,0 +1,39 @@
+import contextlib
+import sys
+import time
+
+import torch
+
+if sys.version_info >= (3, 7):
+
+ @contextlib.contextmanager
+ def profile_time(trace_name,
+ name,
+ enabled=True,
+ stream=None,
+ end_stream=None):
+ """Print time spent by CPU and GPU.
+
+ Useful as a temporary context manager to find sweet spots of code
+ suitable for async implementation.
+ """
+ if (not enabled) or not torch.cuda.is_available():
+ yield
+ return
+ stream = stream if stream else torch.cuda.current_stream()
+ end_stream = end_stream if end_stream else stream
+ start = torch.cuda.Event(enable_timing=True)
+ end = torch.cuda.Event(enable_timing=True)
+ stream.record_event(start)
+ try:
+ cpu_start = time.monotonic()
+ yield
+ finally:
+ cpu_end = time.monotonic()
+ end_stream.record_event(end)
+ end.synchronize()
+ cpu_time = (cpu_end - cpu_start) * 1000
+ gpu_time = start.elapsed_time(end)
+ msg = f'{trace_name} {name} cpu_time {cpu_time:.2f} ms '
+ msg += f'gpu_time {gpu_time:.2f} ms stream {stream}'
+ print(msg, end_stream)
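+
+# A usage sketch (hypothetical call site):
+#
+#   with profile_time('detector', 'forward'):
+#       result = model(return_loss=False, **data)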
diff --git a/mmdet/utils/util_mixins.py b/mmdet/utils/util_mixins.py
new file mode 100644
index 0000000000000000000000000000000000000000..69669a3ca943eebe0f138b2784c5b61724196bbe
--- /dev/null
+++ b/mmdet/utils/util_mixins.py
@@ -0,0 +1,104 @@
+"""This module defines the :class:`NiceRepr` mixin class, which defines a
+``__repr__`` and ``__str__`` method that only depend on a custom ``__nice__``
+method, which you must define. This means you only have to overload one
+function instead of two. Furthermore, if the object defines a ``__len__``
+method, then the ``__nice__`` method defaults to something sensible, otherwise
+it is treated as abstract and raises ``NotImplementedError``.
+
+To use simply have your object inherit from :class:`NiceRepr`
+(multi-inheritance should be ok).
+
+This code was copied from the ubelt library: https://github.com/Erotemic/ubelt
+
+Example:
+ >>> # Objects that define __nice__ have a default __str__ and __repr__
+ >>> class Student(NiceRepr):
+ ... def __init__(self, name):
+ ... self.name = name
+ ... def __nice__(self):
+ ... return self.name
+ >>> s1 = Student('Alice')
+ >>> s2 = Student('Bob')
+ >>> print(f's1 = {s1}')
+ >>> print(f's2 = {s2}')
+ s1 = <Student(Alice)>
+ s2 = <Student(Bob)>
+
+Example:
+ >>> # Objects that define __len__ have a default __nice__
+ >>> class Group(NiceRepr):
+ ... def __init__(self, data):
+ ... self.data = data
+ ... def __len__(self):
+ ... return len(self.data)
+ >>> g = Group([1, 2, 3])
+ >>> print(f'g = {g}')
+ g = <Group(3)>
+"""
+import warnings
+
+
+class NiceRepr(object):
+ """Inherit from this class and define ``__nice__`` to "nicely" print your
+ objects.
+
+ Defines ``__str__`` and ``__repr__`` in terms of ``__nice__`` function
+ Classes that inherit from :class:`NiceRepr` should redefine ``__nice__``.
+ If the inheriting class has a ``__len__``, method then the default
+ ``__nice__`` method will return its length.
+
+ Example:
+ >>> class Foo(NiceRepr):
+ ... def __nice__(self):
+ ... return 'info'
+ >>> foo = Foo()
+ >>> assert str(foo) == '<Foo(info)>'
+ >>> assert repr(foo).startswith('<Foo(info) at ')
+
+ Example:
+ >>> class Bar(NiceRepr):
+ ... pass
+ >>> bar = Bar()
+ >>> import pytest
+ >>> with pytest.warns(None) as record:
+ >>> assert 'object at' in str(bar)
+ >>> assert 'object at' in repr(bar)
+
+ Example:
+ >>> class Baz(NiceRepr):
+ ... def __len__(self):
+ ... return 5
+ >>> baz = Baz()
+ >>> assert str(baz) == '<Baz(5)>'
+ """
+
+ def __nice__(self):
+ """str: a "nice" summary string describing this module"""
+ if hasattr(self, '__len__'):
+ # It is a common pattern for objects to use __len__ in __nice__
+ # As a convenience we define a default __nice__ for these objects
+ return str(len(self))
+ else:
+ # In all other cases force the subclass to overload __nice__
+ raise NotImplementedError(
+ f'Define the __nice__ method for {self.__class__!r}')
+
+ def __repr__(self):
+ """str: the string of the module"""
+ try:
+ nice = self.__nice__()
+ classname = self.__class__.__name__
+ return f'<{classname}({nice}) at {hex(id(self))}>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
+
+ def __str__(self):
+ """str: the string of the module"""
+ try:
+ classname = self.__class__.__name__
+ nice = self.__nice__()
+ return f'<{classname}({nice})>'
+ except NotImplementedError as ex:
+ warnings.warn(str(ex), category=RuntimeWarning)
+ return object.__repr__(self)
diff --git a/mmdet/utils/util_random.py b/mmdet/utils/util_random.py
new file mode 100644
index 0000000000000000000000000000000000000000..e313e9947bb3232a9458878fd219e1594ab93d57
--- /dev/null
+++ b/mmdet/utils/util_random.py
@@ -0,0 +1,33 @@
+"""Helpers for random number generators."""
+import numpy as np
+
+
+def ensure_rng(rng=None):
+ """Coerces input into a random number generator.
+
+ If the input is None, then a global random state is returned.
+
+ If the input is a numeric value, then that is used as a seed to construct a
+ random state. Otherwise the input is returned as-is.
+
+ Adapted from [1]_.
+
+ Args:
+ rng (int | numpy.random.RandomState | None):
+ if None, then defaults to the global rng. Otherwise this can be an
+ integer or a RandomState class
+ Returns:
+ (numpy.random.RandomState) : rng -
+ a numpy random number generator
+
+ References:
+ .. [1] https://gitlab.kitware.com/computer-vision/kwarray/blob/master/kwarray/util_random.py#L270 # noqa: E501
+ """
+
+ if rng is None:
+ rng = np.random.mtrand._rand
+ elif isinstance(rng, int):
+ rng = np.random.RandomState(rng)
+ # any other input (e.g. an existing RandomState) is returned as-is
+ return rng
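+
+# A usage sketch: seeds give reproducible generators and existing generators
+# pass through unchanged.
+#
+#   rng = ensure_rng(0)            # seeded RandomState
+#   assert ensure_rng(rng) is rng  # pass-through
+#   ensure_rng()                   # numpy's global RandomState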
diff --git a/mmdet/version.py b/mmdet/version.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3b741aed16212ad1dee277d519b259ae3184b19
--- /dev/null
+++ b/mmdet/version.py
@@ -0,0 +1,19 @@
+# Copyright (c) Open-MMLab. All rights reserved.
+
+__version__ = '2.11.0'
+short_version = __version__
+
+
+def parse_version_info(version_str):
+ version_info = []
+ for x in version_str.split('.'):
+ if x.isdigit():
+ version_info.append(int(x))
+ elif x.find('rc') != -1:
+ patch_version = x.split('rc')
+ version_info.append(int(patch_version[0]))
+ version_info.append(f'rc{patch_version[1]}')
+ return tuple(version_info)
+
+
+version_info = parse_version_info(__version__)
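
For reference, numeric components become ints while release-candidate suffixes stay strings; a quick check of the parser:

```python
assert parse_version_info('2.11.0') == (2, 11, 0)
assert parse_version_info('2.12.0rc1') == (2, 12, 0, 'rc1')
```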
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..85015de1c539c5877d7f2b151aeef327383950a4
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,13 @@
+--find-links https://download.pytorch.org/whl/torch_stable.html
+torch==1.9.0+cpu
+--find-links https://download.openmmlab.com/mmcv/dist/cpu/torch1.9.0/index.html
+mmcv-full==1.4.0
+gradio==3.0.20
+timm
+scikit-image
+imagesize
+torchvision==0.10.0+cpu
+imantics
+terminaltables
+pycocotools
+mmdet==2.16.0
diff --git a/test.py b/test.py
new file mode 100644
index 0000000000000000000000000000000000000000..92332cd994d28041b285151d79a1dc1001749eba
--- /dev/null
+++ b/test.py
@@ -0,0 +1,226 @@
+import argparse
+import os
+import warnings
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.cnn import fuse_conv_bn
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (get_dist_info, init_dist, load_checkpoint,
+ wrap_fp16_model)
+
+from mmdet.apis import multi_gpu_test, single_gpu_test
+from walt.datasets import (build_dataloader, build_dataset,
+ replace_ImageToTensor)
+from mmdet.models import build_detector
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(
+ description='MMDet test (and eval) a model')
+ parser.add_argument('config', help='test config file path')
+ parser.add_argument('checkpoint', help='checkpoint file')
+ parser.add_argument('--out', help='output result file in pickle format')
+ parser.add_argument(
+ '--fuse-conv-bn',
+ action='store_true',
+        help='Whether to fuse conv and bn; this will slightly increase '
+        'the inference speed')
+ parser.add_argument(
+ '--format-only',
+ action='store_true',
+        help='Format the output results without performing evaluation. It is '
+        'useful when you want to format the results to a specific format and '
+        'submit them to the test server')
+ parser.add_argument(
+ '--eval',
+ type=str,
+ nargs='+',
+ help='evaluation metrics, which depends on the dataset, e.g., "bbox",'
+ ' "segm", "proposal" for COCO, and "mAP", "recall" for PASCAL VOC')
+ parser.add_argument('--show', action='store_true', help='show results')
+ parser.add_argument(
+ '--show-dir', help='directory where painted images will be saved')
+ parser.add_argument(
+ '--show-score-thr',
+ type=float,
+ default=0.3,
+ help='score threshold (default: 0.3)')
+ parser.add_argument(
+ '--gpu-collect',
+ action='store_true',
+ help='whether to use gpu to collect results.')
+ parser.add_argument(
+ '--tmpdir',
+ help='tmp directory used for collecting results from multiple '
+ 'workers, available when gpu-collect is not specified')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.')
+ parser.add_argument(
+ '--options',
+ nargs='+',
+ action=DictAction,
+        help='custom options for evaluation, the key-value pair in xxx=yyy '
+        'format will be kwargs for the dataset.evaluate() function '
+        '(deprecated); change to --eval-options instead.')
+ parser.add_argument(
+ '--eval-options',
+ nargs='+',
+ action=DictAction,
+ help='custom options for evaluation, the key-value pair in xxx=yyy '
+ 'format will be kwargs for dataset.evaluate() function')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ if args.options and args.eval_options:
+ raise ValueError(
+ '--options and --eval-options cannot be both '
+ 'specified, --options is deprecated in favor of --eval-options')
+ if args.options:
+ warnings.warn('--options is deprecated in favor of --eval-options')
+ args.eval_options = args.options
+ return args
+
+
+def main():
+ args = parse_args()
+
+ assert args.out or args.eval or args.format_only or args.show \
+ or args.show_dir, \
+ ('Please specify at least one operation (save/eval/format/show the '
+ 'results / save the results) with the argument "--out", "--eval"'
+ ', "--format-only", "--show" or "--show-dir"')
+
+ if args.eval and args.format_only:
+ raise ValueError('--eval and --format_only cannot be both specified')
+
+ if args.out is not None and not args.out.endswith(('.pkl', '.pickle')):
+ raise ValueError('The output file must be a pkl file.')
+
+ cfg = Config.fromfile(args.config)
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+ # import modules from string list.
+ if cfg.get('custom_imports', None):
+ from mmcv.utils import import_modules_from_strings
+ import_modules_from_strings(**cfg['custom_imports'])
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+ cfg.model.pretrained = None
+ if cfg.model.get('neck'):
+ if isinstance(cfg.model.neck, list):
+ for neck_cfg in cfg.model.neck:
+ if neck_cfg.get('rfp_backbone'):
+ if neck_cfg.rfp_backbone.get('pretrained'):
+ neck_cfg.rfp_backbone.pretrained = None
+ elif cfg.model.neck.get('rfp_backbone'):
+ if cfg.model.neck.rfp_backbone.get('pretrained'):
+ cfg.model.neck.rfp_backbone.pretrained = None
+
+ # in case the test dataset is concatenated
+ samples_per_gpu = 7
+ if isinstance(cfg.data.test, dict):
+ cfg.data.test.test_mode = True
+ samples_per_gpu = cfg.data.test.pop('samples_per_gpu', 1)
+ if samples_per_gpu > 1:
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+ cfg.data.test.pipeline = replace_ImageToTensor(
+ cfg.data.test.pipeline)
+ elif isinstance(cfg.data.test, list):
+ for ds_cfg in cfg.data.test:
+ ds_cfg.test_mode = True
+ samples_per_gpu = max(
+ [ds_cfg.pop('samples_per_gpu', 1) for ds_cfg in cfg.data.test])
+ if samples_per_gpu > 1:
+ for ds_cfg in cfg.data.test:
+ ds_cfg.pipeline = replace_ImageToTensor(ds_cfg.pipeline)
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+
+ # build the dataloader
+    print(samples_per_gpu, cfg.data.workers_per_gpu)
+ dataset = build_dataset(cfg.data.test)
+ data_loader = build_dataloader(
+ dataset,
+ samples_per_gpu=samples_per_gpu,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False)
+
+ # build the model and load checkpoint
+ cfg.model.train_cfg = None
+ model = build_detector(cfg.model, test_cfg=cfg.get('test_cfg'))
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ wrap_fp16_model(model)
+ checkpoint = load_checkpoint(model, args.checkpoint, map_location='cpu')
+ if args.fuse_conv_bn:
+ model = fuse_conv_bn(model)
+    # old versions did not save class info in checkpoints; this workaround is
+    # for backward compatibility
+ if 'CLASSES' in checkpoint.get('meta', {}):
+ model.CLASSES = checkpoint['meta']['CLASSES']
+ else:
+ model.CLASSES = dataset.CLASSES
+
+ if not distributed:
+ model = MMDataParallel(model, device_ids=[0])
+ outputs = single_gpu_test(model, data_loader, args.show, args.show_dir,
+ args.show_score_thr)
+ else:
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False)
+ outputs = multi_gpu_test(model, data_loader, args.tmpdir,
+ args.gpu_collect)
+ import numpy as np
+
+ rank, _ = get_dist_info()
+ if rank == 0:
+ if args.out:
+ print(f'\nwriting results to {args.out}')
+ mmcv.dump(outputs, args.out)
+ kwargs = {} if args.eval_options is None else args.eval_options
+ if args.format_only:
+ dataset.format_results(outputs, **kwargs)
+ if args.eval:
+ eval_kwargs = cfg.get('evaluation', {}).copy()
+ # hard-code way to remove EvalHook args
+ for key in [
+ 'interval', 'tmpdir', 'start', 'gpu_collect', 'save_best',
+ 'rule'
+ ]:
+ eval_kwargs.pop(key, None)
+ eval_kwargs.update(dict(metric=args.eval, **kwargs))
+            data_evaluated = dataset.evaluate(outputs, **eval_kwargs)
+            np.save(args.checkpoint + '_new1', data_evaluated)
+            print(data_evaluated)
+
+
+if __name__ == '__main__':
+ main()
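
The `--cfg-options` flag documented above leans on `mmcv.Config.merge_from_dict`, which accepts dotted keys. A minimal sketch of the merge semantics, using a hypothetical config:

```python
from mmcv import Config

cfg = Config(dict(data=dict(test=dict(samples_per_gpu=1))))
# Equivalent to passing: --cfg-options data.test.samples_per_gpu=4
cfg.merge_from_dict({'data.test.samples_per_gpu': 4})
assert cfg.data.test.samples_per_gpu == 4
```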
diff --git a/train.py b/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f11191f08da30857ae17d6c498c746b1d184f5
--- /dev/null
+++ b/train.py
@@ -0,0 +1,191 @@
+import argparse
+import copy
+import os
+import os.path as osp
+import time
+import warnings
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist
+from mmcv.utils import get_git_hash
+
+from mmdet import __version__
+from mmdet.apis import set_random_seed
+from mmdet.models import build_detector
+from mmdet.utils import collect_env, get_root_logger
+from walt.apis import train_detector
+from walt.datasets import build_dataset
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train a detector')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('--work-dir', help='the dir to save logs and models')
+ parser.add_argument(
+ '--resume-from', help='the checkpoint file to resume from')
+ parser.add_argument(
+ '--no-validate',
+ action='store_true',
+ help='whether not to evaluate the checkpoint during training')
+ group_gpus = parser.add_mutually_exclusive_group()
+ group_gpus.add_argument(
+ '--gpus',
+ type=int,
+ help='number of gpus to use '
+ '(only applicable to non-distributed training)')
+ group_gpus.add_argument(
+ '--gpu-ids',
+ type=int,
+ nargs='+',
+ help='ids of gpus to use '
+ '(only applicable to non-distributed training)')
+ parser.add_argument('--seed', type=int, default=None, help='random seed')
+ parser.add_argument(
+ '--deterministic',
+ action='store_true',
+ help='whether to set deterministic options for CUDNN backend.')
+ parser.add_argument(
+ '--options',
+ nargs='+',
+ action=DictAction,
+        help='override some settings in the used config, the key-value pair '
+        'in xxx=yyy format will be merged into the config file (deprecated); '
+        'change to --cfg-options instead.')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ if args.options and args.cfg_options:
+ raise ValueError(
+ '--options and --cfg-options cannot be both '
+ 'specified, --options is deprecated in favor of --cfg-options')
+ if args.options:
+ warnings.warn('--options is deprecated in favor of --cfg-options')
+ args.cfg_options = args.options
+
+ return args
+
+
+def main():
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+ # import modules from string list.
+ if cfg.get('custom_imports', None):
+ from mmcv.utils import import_modules_from_strings
+ import_modules_from_strings(**cfg['custom_imports'])
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+
+ # work_dir is determined in this priority: CLI > segment in file > filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ cfg.work_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.config))[0])
+
+ if args.resume_from is not None:
+ cfg.resume_from = args.resume_from
+ if args.gpu_ids is not None:
+ cfg.gpu_ids = args.gpu_ids
+ else:
+ cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+ # re-set gpu_ids with distributed training mode
+ _, world_size = get_dist_info()
+ cfg.gpu_ids = range(world_size)
+
+ # create work_dir
+ mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+ # dump config
+ cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+ # init the logger before other steps
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+ logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+ # init the meta dict to record some important information such as
+ # environment info and seed, which will be logged
+ meta = dict()
+ # log env info
+ env_info_dict = collect_env()
+ env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+ dash_line = '-' * 60 + '\n'
+ logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+ dash_line)
+ meta['env_info'] = env_info
+ meta['config'] = cfg.pretty_text
+ # log some basic info
+ logger.info(f'Distributed training: {distributed}')
+ logger.info(f'Config:\n{cfg.pretty_text}')
+
+ # set random seeds
+ if args.seed is not None:
+ logger.info(f'Set random seed to {args.seed}, '
+ f'deterministic: {args.deterministic}')
+ set_random_seed(args.seed, deterministic=args.deterministic)
+ cfg.seed = args.seed
+ meta['seed'] = args.seed
+ meta['exp_name'] = osp.basename(args.config)
+
+ model = build_detector(
+ cfg.model,
+ train_cfg=cfg.get('train_cfg'),
+ test_cfg=cfg.get('test_cfg'))
+
+ datasets = [build_dataset(cfg.data.train)]
+ if len(cfg.workflow) == 2:
+ val_dataset = copy.deepcopy(cfg.data.val)
+ val_dataset.pipeline = cfg.data.train.pipeline
+ datasets.append(build_dataset(val_dataset))
+ if cfg.checkpoint_config is not None:
+ # save mmdet version, config file content and class names in
+ # checkpoints as meta data
+ cfg.checkpoint_config.meta = dict(
+ mmdet_version=__version__ + get_git_hash()[:7],
+ CLASSES=datasets[0].CLASSES)
+
+ # add an attribute for visualization convenience
+ model.CLASSES = datasets[0].CLASSES
+ train_detector(
+ model,
+ datasets,
+ cfg,
+ distributed=distributed,
+ validate=(not args.no_validate),
+ timestamp=timestamp,
+ meta=meta)
+
+
+if __name__ == '__main__':
+ main()
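
When neither the CLI nor the config names a work directory, the fallback above derives one from the config filename. With a hypothetical config path:

```python
import os.path as osp

config_path = 'configs/walt/walt_vehicle.py'  # illustrative path
work_dir = osp.join('./work_dirs', osp.splitext(osp.basename(config_path))[0])
assert work_dir == './work_dirs/walt_vehicle'
```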
diff --git a/walt/apis/__init__.py b/walt/apis/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..bdd0a928ef5b579f84d0ea7946cb0fea3abcf9f0
--- /dev/null
+++ b/walt/apis/__init__.py
@@ -0,0 +1,6 @@
+from .train import get_root_logger, set_random_seed, train_detector
+
+
+__all__ = [
+ 'get_root_logger', 'set_random_seed', 'train_detector'
+]
diff --git a/walt/apis/train.py b/walt/apis/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c8003d5fdf20a3d6a04ab4a031b053cf56d49c7
--- /dev/null
+++ b/walt/apis/train.py
@@ -0,0 +1,187 @@
+import random
+import warnings
+
+import numpy as np
+import torch
+from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
+from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
+ Fp16OptimizerHook, OptimizerHook, build_optimizer,
+ build_runner)
+from mmcv.utils import build_from_cfg
+
+from mmdet.core import DistEvalHook, EvalHook
+from walt.datasets import (build_dataloader, build_dataset,
+ replace_ImageToTensor)
+from mmdet.utils import get_root_logger
+from mmcv_custom.runner import EpochBasedRunnerAmp
+try:
+ import apex
+except ImportError:
+ print('apex is not installed')
+
+
+def set_random_seed(seed, deterministic=False):
+ """Set random seed.
+
+ Args:
+ seed (int): Seed to be used.
+ deterministic (bool): Whether to set the deterministic option for
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
+ to True and `torch.backends.cudnn.benchmark` to False.
+ Default: False.
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ if deterministic:
+ torch.backends.cudnn.deterministic = True
+ torch.backends.cudnn.benchmark = False
+
+
+def train_detector(model,
+ dataset,
+ cfg,
+ distributed=False,
+ validate=False,
+ timestamp=None,
+ meta=None):
+ logger = get_root_logger(cfg.log_level)
+
+ # prepare data loaders
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
+ if 'imgs_per_gpu' in cfg.data:
+ logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
+ 'Please use "samples_per_gpu" instead')
+ if 'samples_per_gpu' in cfg.data:
+            logger.warning(
+                f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
+                f'"samples_per_gpu"={cfg.data.samples_per_gpu}; "imgs_per_gpu"'
+                f'={cfg.data.imgs_per_gpu} is used in this experiment')
+        else:
+            logger.warning(
+                'Automatically set "samples_per_gpu"="imgs_per_gpu"='
+                f'{cfg.data.imgs_per_gpu} in this experiment')
+ cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
+
+ data_loaders = [
+ build_dataloader(
+ ds,
+ cfg.data.samples_per_gpu,
+ cfg.data.workers_per_gpu,
+ # cfg.gpus will be ignored if distributed
+ len(cfg.gpu_ids),
+ dist=distributed,
+ seed=cfg.seed) for ds in dataset
+ ]
+
+ # build optimizer
+ optimizer = build_optimizer(model, cfg.optimizer)
+
+ # use apex fp16 optimizer
+ if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
+ if cfg.optimizer_config.get("use_fp16", False):
+ model, optimizer = apex.amp.initialize(
+ model.cuda(), optimizer, opt_level="O1")
+ for m in model.modules():
+ if hasattr(m, "fp16_enabled"):
+ m.fp16_enabled = True
+
+ # put model on gpus
+ if distributed:
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
+ # Sets the `find_unused_parameters` parameter in
+ # torch.nn.parallel.DistributedDataParallel
+ model = MMDistributedDataParallel(
+ model.cuda(),
+ device_ids=[torch.cuda.current_device()],
+ broadcast_buffers=False,
+ find_unused_parameters=find_unused_parameters)
+ else:
+ model = MMDataParallel(
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
+
+ if 'runner' not in cfg:
+ cfg.runner = {
+ 'type': 'EpochBasedRunner',
+ 'max_epochs': cfg.total_epochs
+ }
+ warnings.warn(
+ 'config is now expected to have a `runner` section, '
+ 'please set `runner` in your config.', UserWarning)
+ else:
+ if 'total_epochs' in cfg:
+ assert cfg.total_epochs == cfg.runner.max_epochs
+
+ # build runner
+ runner = build_runner(
+ cfg.runner,
+ default_args=dict(
+ model=model,
+ optimizer=optimizer,
+ work_dir=cfg.work_dir,
+ logger=logger,
+ meta=meta))
+
+ # an ugly workaround to make .log and .log.json filenames the same
+ runner.timestamp = timestamp
+
+ # fp16 setting
+ fp16_cfg = cfg.get('fp16', None)
+ if fp16_cfg is not None:
+ optimizer_config = Fp16OptimizerHook(
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
+ elif distributed and 'type' not in cfg.optimizer_config:
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
+ else:
+ optimizer_config = cfg.optimizer_config
+
+ # register hooks
+ runner.register_training_hooks(cfg.lr_config, optimizer_config,
+ cfg.checkpoint_config, cfg.log_config,
+ cfg.get('momentum_config', None))
+ if distributed:
+ if isinstance(runner, EpochBasedRunner):
+ runner.register_hook(DistSamplerSeedHook())
+
+ # register eval hooks
+ if validate:
+ # Support batch_size > 1 in validation
+ val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
+ if val_samples_per_gpu > 1:
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
+ cfg.data.val.pipeline = replace_ImageToTensor(
+ cfg.data.val.pipeline)
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
+ val_dataloader = build_dataloader(
+ val_dataset,
+ samples_per_gpu=val_samples_per_gpu,
+ workers_per_gpu=cfg.data.workers_per_gpu,
+ dist=distributed,
+ shuffle=False)
+ eval_cfg = cfg.get('evaluation', {})
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
+ eval_hook = DistEvalHook if distributed else EvalHook
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
+
+ # user-defined hooks
+ if cfg.get('custom_hooks', None):
+ custom_hooks = cfg.custom_hooks
+ assert isinstance(custom_hooks, list), \
+ f'custom_hooks expect list type, but got {type(custom_hooks)}'
+ for hook_cfg in cfg.custom_hooks:
+ assert isinstance(hook_cfg, dict), \
+ 'Each item in custom_hooks expects dict type, but got ' \
+ f'{type(hook_cfg)}'
+ hook_cfg = hook_cfg.copy()
+ priority = hook_cfg.pop('priority', 'NORMAL')
+ hook = build_from_cfg(hook_cfg, HOOKS)
+ runner.register_hook(hook, priority=priority)
+
+ if cfg.resume_from:
+ runner.resume(cfg.resume_from)
+ elif cfg.load_from:
+ runner.load_checkpoint(cfg.load_from)
+ runner.run(data_loaders, cfg.workflow)
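
The custom-hook loop above expects `cfg.custom_hooks` to be a list of dicts, each carrying a HOOKS-registered `type` and an optional `priority`. A hedged config sketch (`NumClassCheckHook` is re-exported by this repo; any registered hook type works the same way):

```python
# In a config file consumed by train_detector:
custom_hooks = [
    dict(type='NumClassCheckHook', priority='NORMAL'),
]
```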
diff --git a/walt/datasets/__init__.py b/walt/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b6b616c1be7cf9841de63293ee9d41e03a057f
--- /dev/null
+++ b/walt/datasets/__init__.py
@@ -0,0 +1,29 @@
+from .builder import DATASETS, PIPELINES, build_dataloader, build_dataset
+from mmdet.datasets.cityscapes import CityscapesDataset
+from mmdet.datasets.coco import CocoDataset
+from .custom import CustomDatasetLocal
+from mmdet.datasets.custom import CustomDataset
+from mmdet.datasets.dataset_wrappers import (ClassBalancedDataset, ConcatDataset,
+ RepeatDataset)
+from mmdet.datasets.deepfashion import DeepFashionDataset
+from mmdet.datasets.lvis import LVISDataset, LVISV1Dataset, LVISV05Dataset
+from mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
+from mmdet.datasets.utils import (NumClassCheckHook, get_loading_pipeline,
+ replace_ImageToTensor)
+from mmdet.datasets.voc import VOCDataset
+from mmdet.datasets.wider_face import WIDERFaceDataset
+from mmdet.datasets.xml_style import XMLDataset
+from .walt_synthetic import WaltSynthDataset
+from .walt_3d import Walt3DDataset
+from .walt import WaltDataset
+__all__ = [
+ 'CustomDataset', 'XMLDataset', 'CocoDataset', 'DeepFashionDataset',
+ 'VOCDataset', 'CityscapesDataset', 'LVISDataset', 'LVISV05Dataset',
+ 'LVISV1Dataset', 'GroupSampler', 'DistributedGroupSampler',
+ 'DistributedSampler', 'build_dataloader', 'ConcatDataset', 'RepeatDataset',
+ 'ClassBalancedDataset', 'Walt3DDataset','WIDERFaceDataset', 'DATASETS', 'PIPELINES',
+ 'build_dataset', 'replace_ImageToTensor', 'get_loading_pipeline',
+ 'WaltSynthDataset', 'WaltDataset', 'NumClassCheckHook'
+]
diff --git a/walt/datasets/builder.py b/walt/datasets/builder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc0fe466f5bfbf903438a5dc979329debd6517f
--- /dev/null
+++ b/walt/datasets/builder.py
@@ -0,0 +1,143 @@
+import copy
+import platform
+import random
+from functools import partial
+
+import numpy as np
+from mmcv.parallel import collate
+from mmcv.runner import get_dist_info
+from mmcv.utils import Registry, build_from_cfg
+from torch.utils.data import DataLoader
+
+from mmdet.datasets.samplers import DistributedGroupSampler, DistributedSampler, GroupSampler
+
+if platform.system() != 'Windows':
+ # https://github.com/pytorch/pytorch/issues/973
+ import resource
+ rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
+ hard_limit = rlimit[1]
+ soft_limit = min(4096, hard_limit)
+ resource.setrlimit(resource.RLIMIT_NOFILE, (soft_limit, hard_limit))
+
+DATASETS = Registry('dataset')
+PIPELINES = Registry('pipeline')
+
+
+def _concat_dataset(cfg, default_args=None):
+ from mmdet.datasets.dataset_wrappers import ConcatDataset
+ ann_files = cfg['ann_file']
+ img_prefixes = cfg.get('img_prefix', None)
+ seg_prefixes = cfg.get('seg_prefix', None)
+ proposal_files = cfg.get('proposal_file', None)
+ separate_eval = cfg.get('separate_eval', True)
+
+ datasets = []
+ num_dset = len(ann_files)
+ for i in range(num_dset):
+ data_cfg = copy.deepcopy(cfg)
+ # pop 'separate_eval' since it is not a valid key for common datasets.
+ if 'separate_eval' in data_cfg:
+ data_cfg.pop('separate_eval')
+ data_cfg['ann_file'] = ann_files[i]
+ if isinstance(img_prefixes, (list, tuple)):
+ data_cfg['img_prefix'] = img_prefixes[i]
+ if isinstance(seg_prefixes, (list, tuple)):
+ data_cfg['seg_prefix'] = seg_prefixes[i]
+ if isinstance(proposal_files, (list, tuple)):
+ data_cfg['proposal_file'] = proposal_files[i]
+ datasets.append(build_dataset(data_cfg, default_args))
+
+ return ConcatDataset(datasets, separate_eval)
+
+
+def build_dataset(cfg, default_args=None):
+ from mmdet.datasets.dataset_wrappers import (ConcatDataset, RepeatDataset,
+ ClassBalancedDataset)
+ if isinstance(cfg, (list, tuple)):
+ dataset = ConcatDataset([build_dataset(c, default_args) for c in cfg])
+ elif cfg['type'] == 'ConcatDataset':
+ dataset = ConcatDataset(
+ [build_dataset(c, default_args) for c in cfg['datasets']],
+ cfg.get('separate_eval', True))
+ elif cfg['type'] == 'RepeatDataset':
+ dataset = RepeatDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['times'])
+ elif cfg['type'] == 'ClassBalancedDataset':
+ dataset = ClassBalancedDataset(
+ build_dataset(cfg['dataset'], default_args), cfg['oversample_thr'])
+ elif isinstance(cfg.get('ann_file'), (list, tuple)):
+ dataset = _concat_dataset(cfg, default_args)
+ else:
+ dataset = build_from_cfg(cfg, DATASETS, default_args)
+
+ return dataset
+
+
+def build_dataloader(dataset,
+ samples_per_gpu,
+ workers_per_gpu,
+ num_gpus=1,
+ dist=True,
+ shuffle=True,
+ seed=None,
+ **kwargs):
+ """Build PyTorch DataLoader.
+
+ In distributed training, each GPU/process has a dataloader.
+ In non-distributed training, there is only one dataloader for all GPUs.
+
+ Args:
+ dataset (Dataset): A PyTorch dataset.
+ samples_per_gpu (int): Number of training samples on each GPU, i.e.,
+ batch size of each GPU.
+ workers_per_gpu (int): How many subprocesses to use for data loading
+ for each GPU.
+ num_gpus (int): Number of GPUs. Only used in non-distributed training.
+ dist (bool): Distributed training/test or not. Default: True.
+ shuffle (bool): Whether to shuffle the data at every epoch.
+ Default: True.
+ kwargs: any keyword argument to be used to initialize DataLoader
+
+ Returns:
+ DataLoader: A PyTorch dataloader.
+ """
+ rank, world_size = get_dist_info()
+ if dist:
+ # DistributedGroupSampler will definitely shuffle the data to satisfy
+ # that images on each GPU are in the same group
+ if shuffle:
+ sampler = DistributedGroupSampler(
+ dataset, samples_per_gpu, world_size, rank, seed=seed)
+ else:
+ sampler = DistributedSampler(
+ dataset, world_size, rank, shuffle=False, seed=seed)
+ batch_size = samples_per_gpu
+ num_workers = workers_per_gpu
+ else:
+ sampler = GroupSampler(dataset, samples_per_gpu) if shuffle else None
+ batch_size = num_gpus * samples_per_gpu
+ num_workers = num_gpus * workers_per_gpu
+
+ init_fn = partial(
+ worker_init_fn, num_workers=num_workers, rank=rank,
+ seed=seed) if seed is not None else None
+
+ data_loader = DataLoader(
+ dataset,
+ batch_size=batch_size,
+ sampler=sampler,
+ num_workers=num_workers,
+ collate_fn=partial(collate, samples_per_gpu=samples_per_gpu),
+ pin_memory=False,
+ worker_init_fn=init_fn,
+ **kwargs)
+
+ return data_loader
+
+
+def worker_init_fn(worker_id, num_workers, rank, seed):
+    # The seed of each worker equals num_workers * rank + worker_id + user_seed
+ worker_seed = num_workers * rank + worker_id + seed
+ np.random.seed(worker_seed)
+ random.seed(worker_seed)
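
Two numeric details above are easy to misread: non-distributed loaders batch `num_gpus * samples_per_gpu` items, and `worker_init_fn` gives every worker process a distinct, reproducible seed. A small sketch with arbitrary values:

```python
# Non-distributed effective batch size: 2 GPUs x 4 samples each -> 8.
num_gpus, samples_per_gpu = 2, 4
assert num_gpus * samples_per_gpu == 8

# Worker seed formula: num_workers * rank + worker_id + user_seed.
num_workers, rank, worker_id, seed = 4, 1, 2, 42
assert num_workers * rank + worker_id + seed == 48
```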
diff --git a/walt/datasets/coco.py b/walt/datasets/coco.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b8b55461b68d0c5c667d9c07503bb98ab14d65b
--- /dev/null
+++ b/walt/datasets/coco.py
@@ -0,0 +1,519 @@
+__author__ = 'tylin'
+__version__ = '2.0'
+# Interface for accessing the Microsoft COCO dataset.
+
+# Microsoft COCO is a large image dataset designed for object detection,
+# segmentation, and caption generation. pycocotools is a Python API that
+# assists in loading, parsing and visualizing the annotations in COCO.
+# Please visit http://mscoco.org/ for more information on COCO, including
+# for the data, paper, and tutorials. The exact format of the annotations
+# is also described on the COCO website. For example usage of the pycocotools
+# please see pycocotools_demo.ipynb. In addition to this API, please download
+# both the COCO images and annotations in order to run the demo.
+
+# An alternative to using the API is to load the annotations directly
+# into a Python dictionary.
+# Using the API provides additional utility functions. Note that this API
+# supports both *instance* and *caption* annotations. In the case of
+# captions not all functions are defined (e.g. categories are undefined).
+
+# The following API functions are defined:
+# COCO - COCO api class that loads COCO annotation file and prepare data
+# structures.
+# decodeMask - Decode binary mask M encoded via run-length encoding.
+# encodeMask - Encode binary mask M using run-length encoding.
+# getAnnIds - Get ann ids that satisfy given filter conditions.
+# getCatIds - Get cat ids that satisfy given filter conditions.
+# getImgIds - Get img ids that satisfy given filter conditions.
+# loadAnns - Load anns with the specified ids.
+# loadCats - Load cats with the specified ids.
+# loadImgs - Load imgs with the specified ids.
+# annToMask - Convert segmentation in an annotation to binary mask.
+# showAnns - Display the specified annotations.
+# loadRes - Load algorithm results and create API for accessing them.
+# download - Download COCO images from mscoco.org server.
+# Throughout the API "ann"=annotation, "cat"=category, and "img"=image.
+# Help on each function can be accessed by: "help COCO>function".
+
+# See also COCO>decodeMask,
+# COCO>encodeMask, COCO>getAnnIds, COCO>getCatIds,
+# COCO>getImgIds, COCO>loadAnns, COCO>loadCats,
+# COCO>loadImgs, COCO>annToMask, COCO>showAnns
+
+# Microsoft COCO Toolbox. version 2.0
+# Data, paper, and tutorials available at: http://mscoco.org/
+# Code written by Piotr Dollar and Tsung-Yi Lin, 2014.
+# Licensed under the Simplified BSD License [see bsd.txt]
+
+import copy
+import itertools
+import json
+import os
+import time
+from collections import defaultdict
+from urllib.request import urlretrieve
+
+import matplotlib.pyplot as plt
+import numpy as np
+from matplotlib.collections import PatchCollection
+from matplotlib.patches import Polygon
+
+from . import mask as maskUtils
+
+
+def _isArrayLike(obj):
+ return hasattr(obj, '__iter__') and hasattr(obj, '__len__')
+
+
+class COCO:
+ def __init__(self, annotation_file=None):
+ """
+ Constructor of Microsoft COCO helper class for reading and visualizing
+ annotations.
+ :param annotation_file (str): location of annotation file
+ :param image_folder (str): location to the folder that hosts images.
+ :return:
+ """
+ # load dataset
+ self.dataset, self.anns, self.cats, self.imgs = dict(), dict(), dict(
+ ), dict()
+ self.imgToAnns, self.catToImgs = defaultdict(list), defaultdict(list)
+ if annotation_file is not None:
+ print('loading annotations into memory...')
+ tic = time.time()
+ with open(annotation_file, 'r') as f:
+ dataset = json.load(f)
+ assert type(
+ dataset
+ ) == dict, 'annotation file format {} not supported'.format(
+ type(dataset))
+ print('Done (t={:0.2f}s)'.format(time.time() - tic))
+ self.dataset = dataset
+ self.createIndex()
+ self.img_ann_map = self.imgToAnns
+ self.cat_img_map = self.catToImgs
+
+ def createIndex(self):
+ # create index
+ print('creating index...')
+ anns, cats, imgs = {}, {}, {}
+ imgToAnns, catToImgs = defaultdict(list), defaultdict(list)
+ if 'annotations' in self.dataset:
+ for ann in self.dataset['annotations']:
+ imgToAnns[ann['image_id']].append(ann)
+ anns[ann['id']] = ann
+
+ if 'images' in self.dataset:
+ for img in self.dataset['images']:
+ imgs[img['id']] = img
+
+ if 'categories' in self.dataset:
+ for cat in self.dataset['categories']:
+ cats[cat['id']] = cat
+
+ if 'annotations' in self.dataset and 'categories' in self.dataset:
+ for ann in self.dataset['annotations']:
+ catToImgs[ann['category_id']].append(ann['image_id'])
+
+ print('index created!')
+
+ # create class members
+ self.anns = anns
+ self.imgToAnns = imgToAnns
+ self.catToImgs = catToImgs
+ self.imgs = imgs
+ self.cats = cats
+
+ def info(self):
+ """
+ Print information about the annotation file.
+ :return:
+ """
+ for key, value in self.dataset['info'].items():
+ print('{}: {}'.format(key, value))
+
+ def getAnnIds(self, imgIds=[], catIds=[], areaRng=[], iscrowd=None):
+ """
+ Get ann ids that satisfy given filter conditions. default skips that
+ filter
+ :param imgIds (int array) : get anns for given imgs
+ catIds (int array) : get anns for given cats
+ areaRng (float array) : get anns for given area range
+ (e.g. [0 inf])
+ iscrowd (boolean) : get anns for given crowd label
+ (False or True)
+ :return: ids (int array) : integer array of ann ids
+ """
+ imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
+
+ if len(imgIds) == len(catIds) == len(areaRng) == 0:
+ anns = self.dataset['annotations']
+ else:
+ if not len(imgIds) == 0:
+ lists = [
+ self.imgToAnns[imgId] for imgId in imgIds
+ if imgId in self.imgToAnns
+ ]
+ anns = list(itertools.chain.from_iterable(lists))
+ else:
+ anns = self.dataset['annotations']
+ anns = anns if len(catIds) == 0 else [
+ ann for ann in anns if ann['category_id'] in catIds
+ ]
+ anns = anns if len(areaRng) == 0 else [
+ ann for ann in anns
+ if ann['area'] > areaRng[0] and ann['area'] < areaRng[1]
+ ]
+ if iscrowd is not None:
+ ids = [ann['id'] for ann in anns if ann['iscrowd'] == iscrowd]
+ else:
+ ids = [ann['id'] for ann in anns]
+ return ids
+
+ def get_ann_ids(self, img_ids=[], cat_ids=[], area_rng=[], iscrowd=None):
+ return self.getAnnIds(img_ids, cat_ids, area_rng, iscrowd)
+
+ def getCatIds(self, catNms=[], supNms=[], catIds=[]):
+ """
+        Get cat ids that satisfy the given filter conditions; default skips
+        each filter.
+ :param catNms (str array) : get cats for given cat names
+ :param supNms (str array) : get cats for given supercategory names
+ :param catIds (int array) : get cats for given cat ids
+ :return: ids (int array) : integer array of cat ids
+ """
+ catNms = catNms if _isArrayLike(catNms) else [catNms]
+ supNms = supNms if _isArrayLike(supNms) else [supNms]
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
+
+ if len(catNms) == len(supNms) == len(catIds) == 0:
+ cats = self.dataset['categories']
+ else:
+ cats = self.dataset['categories']
+ cats = cats if len(catNms) == 0 else [
+ cat for cat in cats if cat['name'] in catNms
+ ]
+ cats = cats if len(supNms) == 0 else [
+ cat for cat in cats if cat['supercategory'] in supNms
+ ]
+ cats = cats if len(catIds) == 0 else [
+ cat for cat in cats if cat['id'] in catIds
+ ]
+ ids = [cat['id'] for cat in cats]
+ return ids
+
+ def get_cat_ids(self, cat_names=[], sup_names=[], cat_ids=[]):
+ return self.getCatIds(cat_names, sup_names, cat_ids)
+
+ def getImgIds(self, imgIds=[], catIds=[]):
+ '''
+ Get img ids that satisfy given filter conditions.
+ :param imgIds (int array) : get imgs for given ids
+ :param catIds (int array) : get imgs with all given cats
+ :return: ids (int array) : integer array of img ids
+ '''
+ imgIds = imgIds if _isArrayLike(imgIds) else [imgIds]
+ catIds = catIds if _isArrayLike(catIds) else [catIds]
+
+ if len(imgIds) == len(catIds) == 0:
+ ids = self.imgs.keys()
+ else:
+ ids = set(imgIds)
+ for i, catId in enumerate(catIds):
+ if i == 0 and len(ids) == 0:
+ ids = set(self.catToImgs[catId])
+ else:
+ ids &= set(self.catToImgs[catId])
+ return list(ids)
+
+ def get_img_ids(self, img_ids=[], cat_ids=[]):
+ return self.getImgIds(img_ids, cat_ids)
+
+ def loadAnns(self, ids=[]):
+ """
+ Load anns with the specified ids.
+ :param ids (int array) : integer ids specifying anns
+ :return: anns (object array) : loaded ann objects
+ """
+ if _isArrayLike(ids):
+ return [self.anns[id] for id in ids]
+ elif type(ids) == int:
+ return [self.anns[ids]]
+
+ load_anns = loadAnns
+
+ def loadCats(self, ids=[]):
+ """
+ Load cats with the specified ids.
+ :param ids (int array) : integer ids specifying cats
+ :return: cats (object array) : loaded cat objects
+ """
+ if _isArrayLike(ids):
+ return [self.cats[id] for id in ids]
+ elif type(ids) == int:
+ return [self.cats[ids]]
+
+ load_cats = loadCats
+
+ def loadImgs(self, ids=[]):
+ """
+        Load imgs with the specified ids.
+ :param ids (int array) : integer ids specifying img
+ :return: imgs (object array) : loaded img objects
+ """
+ if _isArrayLike(ids):
+ return [self.imgs[id] for id in ids]
+ elif type(ids) == int:
+ return [self.imgs[ids]]
+
+ load_imgs = loadImgs
+
+ def showAnns(self, anns, draw_bbox=False):
+ """
+ Display the specified annotations.
+ :param anns (array of object): annotations to display
+ :return: None
+ """
+ if len(anns) == 0:
+ return 0
+ if 'segmentation' in anns[0] or 'keypoints' in anns[0]:
+ datasetType = 'instances'
+ elif 'caption' in anns[0]:
+ datasetType = 'captions'
+ else:
+ raise Exception('datasetType not supported')
+ if datasetType == 'instances':
+ ax = plt.gca()
+ ax.set_autoscale_on(False)
+ polygons = []
+ color = []
+ for ann in anns:
+ c = (np.random.random((1, 3)) * 0.6 + 0.4).tolist()[0]
+ if 'segmentation' in ann:
+ if type(ann['segmentation']) == list:
+ # polygon
+ for seg in ann['segmentation']:
+ poly = np.array(seg).reshape(
+ (int(len(seg) / 2), 2))
+ polygons.append(Polygon(poly))
+ color.append(c)
+ else:
+ # mask
+ t = self.imgs[ann['image_id']]
+ if type(ann['segmentation']['counts']) == list:
+ rle = maskUtils.frPyObjects([ann['segmentation']],
+ t['height'],
+ t['width'])
+ else:
+ rle = [ann['segmentation']]
+ m = maskUtils.decode(rle)
+ img = np.ones((m.shape[0], m.shape[1], 3))
+ if ann['iscrowd'] == 1:
+ color_mask = np.array([2.0, 166.0, 101.0]) / 255
+ if ann['iscrowd'] == 0:
+ color_mask = np.random.random((1, 3)).tolist()[0]
+ for i in range(3):
+ img[:, :, i] = color_mask[i]
+ ax.imshow(np.dstack((img, m * 0.5)))
+ if 'keypoints' in ann and type(ann['keypoints']) == list:
+ # turn skeleton into zero-based index
+ sks = np.array(
+ self.loadCats(ann['category_id'])[0]['skeleton']) - 1
+ kp = np.array(ann['keypoints'])
+ x = kp[0::3]
+ y = kp[1::3]
+ v = kp[2::3]
+ for sk in sks:
+ if np.all(v[sk] > 0):
+ plt.plot(x[sk], y[sk], linewidth=3, color=c)
+ plt.plot(x[v > 0],
+ y[v > 0],
+ 'o',
+ markersize=8,
+ markerfacecolor=c,
+ markeredgecolor='k',
+ markeredgewidth=2)
+ plt.plot(x[v > 1],
+ y[v > 1],
+ 'o',
+ markersize=8,
+ markerfacecolor=c,
+ markeredgecolor=c,
+ markeredgewidth=2)
+
+ if draw_bbox:
+ [bbox_x, bbox_y, bbox_w, bbox_h] = ann['bbox']
+ poly = [[bbox_x, bbox_y], [bbox_x, bbox_y + bbox_h],
+ [bbox_x + bbox_w, bbox_y + bbox_h],
+ [bbox_x + bbox_w, bbox_y]]
+ np_poly = np.array(poly).reshape((4, 2))
+ polygons.append(Polygon(np_poly))
+ color.append(c)
+
+ p = PatchCollection(polygons,
+ facecolor=color,
+ linewidths=0,
+ alpha=0.4)
+ ax.add_collection(p)
+ p = PatchCollection(polygons,
+ facecolor='none',
+ edgecolors=color,
+ linewidths=2)
+ ax.add_collection(p)
+ elif datasetType == 'captions':
+ for ann in anns:
+ print(ann['caption'])
+
+ def loadRes(self, resFile):
+ """
+ Load result file and return a result api object.
+ :param resFile (str) : file name of result file
+ :return: res (obj) : result api object
+ """
+ res = COCO()
+ res.dataset['images'] = [img for img in self.dataset['images']]
+
+ print('Loading and preparing results...')
+ tic = time.time()
+ if type(resFile) == str:
+ with open(resFile) as f:
+ anns = json.load(f)
+ elif type(resFile) == np.ndarray:
+ anns = self.loadNumpyAnnotations(resFile)
+ else:
+ anns = resFile
+        assert type(anns) == list, 'results is not an array of objects'
+ annsImgIds = [ann['image_id'] for ann in anns]
+ assert set(annsImgIds) == (set(annsImgIds) & set(self.getImgIds())), \
+ 'Results do not correspond to current coco set'
+ if 'caption' in anns[0]:
+ imgIds = set([img['id'] for img in res.dataset['images']]) & set(
+ [ann['image_id'] for ann in anns])
+ res.dataset['images'] = [
+ img for img in res.dataset['images'] if img['id'] in imgIds
+ ]
+ for id, ann in enumerate(anns):
+ ann['id'] = id + 1
+ elif 'bbox' in anns[0] and not anns[0]['bbox'] == []:
+ res.dataset['categories'] = copy.deepcopy(
+ self.dataset['categories'])
+ for id, ann in enumerate(anns):
+ bb = ann['bbox']
+ x1, x2, y1, y2 = [bb[0], bb[0] + bb[2], bb[1], bb[1] + bb[3]]
+ if 'segmentation' not in ann:
+ ann['segmentation'] = [[x1, y1, x1, y2, x2, y2, x2, y1]]
+ ann['area'] = bb[2] * bb[3]
+ ann['id'] = id + 1
+ ann['iscrowd'] = 0
+ elif 'segmentation' in anns[0]:
+ res.dataset['categories'] = copy.deepcopy(
+ self.dataset['categories'])
+ for id, ann in enumerate(anns):
+ # now only support compressed RLE format as segmentation
+ # results
+ ann['area'] = maskUtils.area(ann['segmentation'])
+ if 'bbox' not in ann:
+ ann['bbox'] = maskUtils.toBbox(ann['segmentation'])
+ ann['id'] = id + 1
+ ann['iscrowd'] = 0
+ elif 'keypoints' in anns[0]:
+ res.dataset['categories'] = copy.deepcopy(
+ self.dataset['categories'])
+ for id, ann in enumerate(anns):
+ s = ann['keypoints']
+ x = s[0::3]
+ y = s[1::3]
+ x0, x1, y0, y1 = np.min(x), np.max(x), np.min(y), np.max(y)
+ ann['area'] = (x1 - x0) * (y1 - y0)
+ ann['id'] = id + 1
+ ann['bbox'] = [x0, y0, x1 - x0, y1 - y0]
+ print('DONE (t={:0.2f}s)'.format(time.time() - tic))
+
+ res.dataset['annotations'] = anns
+ res.createIndex()
+ return res
+
+ def download(self, tarDir=None, imgIds=[]):
+ '''
+ Download COCO images from mscoco.org server.
+ :param tarDir (str): COCO results directory name
+ imgIds (list): images to be downloaded
+ :return:
+ '''
+ if tarDir is None:
+ print('Please specify target directory')
+ return -1
+ if len(imgIds) == 0:
+ imgs = self.imgs.values()
+ else:
+ imgs = self.loadImgs(imgIds)
+ N = len(imgs)
+ if not os.path.exists(tarDir):
+ os.makedirs(tarDir)
+ for i, img in enumerate(imgs):
+ tic = time.time()
+ fname = os.path.join(tarDir, img['file_name'])
+ if not os.path.exists(fname):
+ urlretrieve(img['coco_url'], fname)
+ print('downloaded {}/{} images (t={:0.1f}s)'.format(
+ i, N,
+ time.time() - tic))
+
+ def loadNumpyAnnotations(self, data):
+ """
+ Convert result data from a numpy array [Nx7] where each row contains
+ {imageID,x1,y1,w,h,score,class}
+ :param data (numpy.ndarray)
+ :return: annotations (python nested list)
+ """
+ print('Converting ndarray to lists...')
+ assert (type(data) == np.ndarray)
+ print(data.shape)
+ assert (data.shape[1] == 7)
+ N = data.shape[0]
+ ann = []
+ for i in range(N):
+ if i % 1000000 == 0:
+ print('{}/{}'.format(i, N))
+ ann += [{
+ 'image_id': int(data[i, 0]),
+ 'bbox': [data[i, 1], data[i, 2], data[i, 3], data[i, 4]],
+ 'score': data[i, 5],
+ 'category_id': int(data[i, 6]),
+ }]
+ return ann
+
+ def annToRLE(self, ann):
+ """
+ Convert annotation which can be polygons, uncompressed RLE to RLE.
+ :return: binary mask (numpy 2D array)
+ """
+ t = self.imgs[ann['image_id']]
+ h, w = t['height'], t['width']
+ segm = ann['segmentation']
+ if type(segm) == list:
+ # polygon -- a single object might consist of multiple parts
+ # we merge all parts into one mask rle code
+ rles = maskUtils.frPyObjects(segm, h, w)
+ rle = maskUtils.merge(rles)
+ elif type(segm['counts']) == list:
+ # uncompressed RLE
+ rle = maskUtils.frPyObjects(segm, h, w)
+ else:
+ # rle
+ rle = ann['segmentation']
+ return rle
+
+ ann_to_rle = annToRLE
+
+ def annToMask(self, ann):
+ """
+ Convert annotation which can be polygons, uncompressed RLE, or RLE to
+ binary mask.
+ :return: binary mask (numpy 2D array)
+ """
+ rle = self.annToRLE(ann)
+ m = maskUtils.decode(rle)
+ return m
+
+ ann_to_mask = annToMask
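
A typical round trip through this API, matching the function list in the header comment (the annotation path is hypothetical):

```python
coco = COCO('data/annotations/instances_val.json')  # hypothetical file
img_ids = coco.getImgIds()
ann_ids = coco.getAnnIds(imgIds=img_ids[:1], iscrowd=None)
anns = coco.loadAnns(ann_ids)
mask = coco.annToMask(anns[0])  # HxW binary numpy array
print(len(img_ids), len(anns), mask.shape)
```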
diff --git a/walt/datasets/cocoeval.py b/walt/datasets/cocoeval.py
new file mode 100644
index 0000000000000000000000000000000000000000..a42a2735b51fa5b8a5f49dfefb48d84121d18484
--- /dev/null
+++ b/walt/datasets/cocoeval.py
@@ -0,0 +1,612 @@
+__author__ = 'tsungyi'
+
+import numpy as np
+import datetime
+import time
+from collections import defaultdict
+import pycocotools.mask as maskUtils
+import copy
+
+
+
+def xywh_to_xyxy(xywh):
+ """Convert [x1 y1 w h] box format to [x1 y1 x2 y2] format."""
+ if isinstance(xywh, (list, tuple)):
+ # Single box given as a list of coordinates
+ assert len(xywh) == 4
+ x1, y1 = xywh[0], xywh[1]
+ x2 = x1 + np.maximum(0., xywh[2] - 1.)
+ y2 = y1 + np.maximum(0., xywh[3] - 1.)
+ return (x1, y1, x2, y2)
+ elif isinstance(xywh, np.ndarray):
+ # Multiple boxes given as a 2D ndarray
+ return np.hstack(
+ (xywh[:, 0:2], xywh[:, 0:2] + np.maximum(0, xywh[:, 2:4] - 1))
+ )
+ else:
+ raise TypeError('Argument xywh must be a list, tuple, or numpy array.')
+
+def get_iou(pred_box, gt_box):
+ """
+ pred_box : the coordinate for predict bounding box
+ gt_box : the coordinate for ground truth bounding box
+ return : the iou score
+ the left-down coordinate of pred_box:(pred_box[0], pred_box[1])
+ the right-up coordinate of pred_box:(pred_box[2], pred_box[3])
+ """
+ pred_box = xywh_to_xyxy(pred_box)
+ gt_box = xywh_to_xyxy(gt_box)
+ # 1.get the coordinate of inters
+ ixmin = max(pred_box[0], gt_box[0])
+ ixmax = min(pred_box[2], gt_box[2])
+ iymin = max(pred_box[1], gt_box[1])
+ iymax = min(pred_box[3], gt_box[3])
+
+ iw = np.maximum(ixmax-ixmin+1., 0.)
+ ih = np.maximum(iymax-iymin+1., 0.)
+
+ # 2. calculate the area of inters
+ inters = iw*ih
+
+ # 3. calculate the area of union
+ uni = ((pred_box[2]-pred_box[0]+1.) * (pred_box[3]-pred_box[1]+1.) +
+ (gt_box[2] - gt_box[0] + 1.) * (gt_box[3] - gt_box[1] + 1.) -
+ inters)
+
+ # 4. calculate the overlaps between pred_box and gt_box
+ iou = inters / uni
+
+ return iou
+
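+# Worked example: with the "+1 pixel" convention above,
+# get_iou([0, 0, 10, 10], [5, 5, 10, 10]) has intersection 5 * 5 = 25 and
+# union 100 + 100 - 25 = 175, so the IoU is 25 / 175 ~= 0.143.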
+
+class COCOeval:
+ # Interface for evaluating detection on the Microsoft COCO dataset.
+ #
+ # The usage for CocoEval is as follows:
+ # cocoGt=..., cocoDt=... # load dataset and results
+ # E = CocoEval(cocoGt,cocoDt); # initialize CocoEval object
+ # E.params.recThrs = ...; # set parameters as desired
+ # E.evaluate(); # run per image evaluation
+ # E.accumulate(); # accumulate per image results
+ # E.summarize(); # display summary metrics of results
+ # For example usage see evalDemo.m and http://mscoco.org/.
+ #
+ # The evaluation parameters are as follows (defaults in brackets):
+ # imgIds - [all] N img ids to use for evaluation
+ # catIds - [all] K cat ids to use for evaluation
+ # iouThrs - [.5:.05:.95] T=10 IoU thresholds for evaluation
+ # recThrs - [0:.01:1] R=101 recall thresholds for evaluation
+ # areaRng - [...] A=4 object area ranges for evaluation
+ # maxDets - [1 10 100] M=3 thresholds on max detections per image
+ # iouType - ['segm'] set iouType to 'segm', 'bbox' or 'keypoints'
+ # iouType replaced the now DEPRECATED useSegm parameter.
+ # useCats - [1] if true use category labels for evaluation
+ # Note: if useCats=0 category labels are ignored as in proposal scoring.
+ # Note: multiple areaRngs [Ax2] and maxDets [Mx1] can be specified.
+ #
+ # evaluate(): evaluates detections on every image and every category and
+ # concats the results into the "evalImgs" with fields:
+ # dtIds - [1xD] id for each of the D detections (dt)
+ # gtIds - [1xG] id for each of the G ground truths (gt)
+ # dtMatches - [TxD] matching gt id at each IoU or 0
+ # gtMatches - [TxG] matching dt id at each IoU or 0
+ # dtScores - [1xD] confidence of each dt
+ # gtIgnore - [1xG] ignore flag for each gt
+ # dtIgnore - [TxD] ignore flag for each dt at each IoU
+ #
+ # accumulate(): accumulates the per-image, per-category evaluation
+ # results in "evalImgs" into the dictionary "eval" with fields:
+ # params - parameters used for evaluation
+ # date - date evaluation was performed
+ # counts - [T,R,K,A,M] parameter dimensions (see above)
+ # precision - [TxRxKxAxM] precision for every evaluation setting
+ # recall - [TxKxAxM] max recall for every evaluation setting
+ # Note: precision and recall==-1 for settings with no gt objects.
+ #
+ # See also coco, mask, pycocoDemo, pycocoEvalDemo
+ #
+ # Microsoft COCO Toolbox. version 2.0
+ # Data, paper, and tutorials available at: http://mscoco.org/
+ # Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
+ # Licensed under the Simplified BSD License [see coco/license.txt]
+ def __init__(self, cocoGt=None, cocoDt=None, iouType='segm'):
+ '''
+ Initialize CocoEval using coco APIs for gt and dt
+ :param cocoGt: coco object with ground truth annotations
+ :param cocoDt: coco object with detection results
+ :return: None
+ '''
+ if not iouType:
+ print('iouType not specified. use default iouType segm')
+ self.cocoGt = cocoGt # ground truth COCO API
+ self.cocoDt = cocoDt # detections COCO API
+ self.evalImgs = defaultdict(list) # per-image per-category evaluation results [KxAxI] elements
+ self.eval = {} # accumulated evaluation results
+ self._gts = defaultdict(list) # gt for evaluation
+ self._dts = defaultdict(list) # dt for evaluation
+ self.params = Params(iouType=iouType) # parameters
+ self._paramsEval = {} # parameters for evaluation
+ self.stats = [] # result summarization
+ self.ious = {} # ious between all gts and dts
+ self.percentage_occ = 0
+ if not cocoGt is None:
+ self.params.imgIds = sorted(cocoGt.getImgIds())
+ self.params.catIds = sorted(cocoGt.getCatIds())
+
+
+ def _prepare(self):
+ '''
+ Prepare ._gts and ._dts for evaluation based on params
+ :return: None
+ '''
+ def _toMask(anns, coco):
+ # modify ann['segmentation'] by reference
+ for ann in anns:
+ rle = coco.annToRLE(ann)
+ ann['segmentation'] = rle
+ p = self.params
+ if p.useCats:
+ gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
+ dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds, catIds=p.catIds))
+ else:
+ gts=self.cocoGt.loadAnns(self.cocoGt.getAnnIds(imgIds=p.imgIds))
+ dts=self.cocoDt.loadAnns(self.cocoDt.getAnnIds(imgIds=p.imgIds))
+
+ if self.percentage_occ >= 0:
+ gts_new = []
+ indices = []
+ for gt in gts:
+ #print(gt['occ_percentage'], self.percentage_occ)
+                if gt['occ_percentage'] >= self.percentage_occ * 10 and gt['occ_percentage'] < (self.percentage_occ + 1) * 10:
+ for ind, dt in enumerate(dts):
+ if ind in indices or dt['image_id'] != gt['image_id']:
+ continue
+ #print(dt['image_id'], gt['image_id'])
+                        if get_iou(gt['bbox'], dt['bbox']) > 0.4:
+ indices.append(ind)
+ gts_new.append(gt)
+
+ dts_new = []
+ for i in np.unique(indices):
+ dts_new.append(dts[i])
+
+ #print(len(gts_new), len(gts), len(dts), len(dts_new), len(indices))
+ dts = dts_new
+ gts = gts_new
+
+ # convert ground truth to mask if iouType == 'segm'
+ if p.iouType == 'segm':
+ _toMask(gts, self.cocoGt)
+ _toMask(dts, self.cocoDt)
+ # set ignore flag
+ for gt in gts:
+ gt['ignore'] = gt['ignore'] if 'ignore' in gt else 0
+ gt['ignore'] = 'iscrowd' in gt and gt['iscrowd']
+ if p.iouType == 'keypoints':
+ gt['ignore'] = (gt['num_keypoints'] == 0) or gt['ignore']
+ self._gts = defaultdict(list) # gt for evaluation
+ self._dts = defaultdict(list) # dt for evaluation
+ for gt in gts:
+ self._gts[gt['image_id'], gt['category_id']].append(gt)
+ for dt in dts:
+ self._dts[dt['image_id'], dt['category_id']].append(dt)
+ self.evalImgs = defaultdict(list) # per-image per-category evaluation results
+ self.eval = {} # accumulated evaluation results
+
+ def evaluate(self):
+ '''
+ Run per image evaluation on given images and store results (a list of dict) in self.evalImgs
+ :return: None
+ '''
+ tic = time.time()
+ print('Running per image evaluation...')
+ p = self.params
+ # add backward compatibility if useSegm is specified in params
+ if not p.useSegm is None:
+ p.iouType = 'segm' if p.useSegm == 1 else 'bbox'
+ print('useSegm (deprecated) is not None. Running {} evaluation'.format(p.iouType))
+ print('Evaluate annotation type *{}*'.format(p.iouType))
+ p.imgIds = list(np.unique(p.imgIds))
+ if p.useCats:
+ p.catIds = list(np.unique(p.catIds))
+ p.maxDets = sorted(p.maxDets)
+ self.params=p
+
+ self._prepare()
+ # loop through images, area range, max detection number
+ catIds = p.catIds if p.useCats else [-1]
+
+ if p.iouType == 'segm' or p.iouType == 'bbox':
+ computeIoU = self.computeIoU
+ elif p.iouType == 'keypoints':
+ computeIoU = self.computeOks
+ self.ious = {(imgId, catId): computeIoU(imgId, catId) \
+ for imgId in p.imgIds
+ for catId in catIds}
+
+ evaluateImg = self.evaluateImg
+ maxDet = p.maxDets[-1]
+ self.evalImgs = [evaluateImg(imgId, catId, areaRng, maxDet)
+ for catId in catIds
+ for areaRng in p.areaRng
+ for imgId in p.imgIds
+ ]
+ self._paramsEval = copy.deepcopy(self.params)
+ toc = time.time()
+ print('DONE (t={:0.2f}s).'.format(toc-tic))
+
+ def computeIoU(self, imgId, catId):
+ p = self.params
+ if p.useCats:
+ gt = self._gts[imgId,catId]
+ dt = self._dts[imgId,catId]
+ else:
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
+ if len(gt) == 0 and len(dt) ==0:
+ return []
+ inds = np.argsort([-d['score'] for d in dt], kind='mergesort')
+ dt = [dt[i] for i in inds]
+ if len(dt) > p.maxDets[-1]:
+ dt=dt[0:p.maxDets[-1]]
+
+ if p.iouType == 'segm':
+ g = [g['segmentation'] for g in gt]
+ d = [d['segmentation'] for d in dt]
+ elif p.iouType == 'bbox':
+ g = [g['bbox'] for g in gt]
+ d = [d['bbox'] for d in dt]
+ else:
+ raise Exception('unknown iouType for iou computation')
+
+ # compute iou between each dt and gt region
+ iscrowd = [int(o['iscrowd']) for o in gt]
+ ious = maskUtils.iou(d,g,iscrowd)
+ return ious
+
+ def computeOks(self, imgId, catId):
+ p = self.params
+ # dimention here should be Nxm
+ gts = self._gts[imgId, catId]
+ dts = self._dts[imgId, catId]
+ inds = np.argsort([-d['score'] for d in dts], kind='mergesort')
+ dts = [dts[i] for i in inds]
+ if len(dts) > p.maxDets[-1]:
+ dts = dts[0:p.maxDets[-1]]
+ # if len(gts) == 0 and len(dts) == 0:
+ if len(gts) == 0 or len(dts) == 0:
+ return []
+ ious = np.zeros((len(dts), len(gts)))
+ sigmas = p.kpt_oks_sigmas
+ vars = (sigmas * 2)**2
+ k = len(sigmas)
+ # compute oks between each detection and ground truth object
+ for j, gt in enumerate(gts):
+ # create bounds for ignore regions(double the gt bbox)
+ g = np.array(gt['keypoints'])
+ xg = g[0::3]; yg = g[1::3]; vg = g[2::3]
+ k1 = np.count_nonzero(vg > 0)
+ bb = gt['bbox']
+ x0 = bb[0] - bb[2]; x1 = bb[0] + bb[2] * 2
+ y0 = bb[1] - bb[3]; y1 = bb[1] + bb[3] * 2
+ for i, dt in enumerate(dts):
+ d = np.array(dt['keypoints'])
+ xd = d[0::3]; yd = d[1::3]
+ if k1>0:
+ # measure the per-keypoint distance if keypoints visible
+ dx = xd - xg
+ dy = yd - yg
+ else:
+ # measure minimum distance to keypoints in (x0,y0) & (x1,y1)
+ z = np.zeros((k))
+ dx = np.max((z, x0-xd),axis=0)+np.max((z, xd-x1),axis=0)
+ dy = np.max((z, y0-yd),axis=0)+np.max((z, yd-y1),axis=0)
+ e = (dx**2 + dy**2) / vars / (gt['area']+np.spacing(1)) / 2
+ if k1 > 0:
+ e=e[vg > 0]
+ ious[i, j] = np.sum(np.exp(-e)) / e.shape[0]
+ return ious
+
+ def evaluateImg(self, imgId, catId, aRng, maxDet):
+ '''
+ perform evaluation for single category and image
+ :return: dict (single image results)
+ '''
+ p = self.params
+ if p.useCats:
+ gt = self._gts[imgId,catId]
+ dt = self._dts[imgId,catId]
+ else:
+ gt = [_ for cId in p.catIds for _ in self._gts[imgId,cId]]
+ dt = [_ for cId in p.catIds for _ in self._dts[imgId,cId]]
+ if len(gt) == 0 and len(dt) ==0:
+ return None
+
+ for g in gt:
+            if g['ignore'] or (g['area'] < aRng[0] or g['area'] > aRng[1]):
+ g['_ignore'] = 1
+ else:
+ g['_ignore'] = 0
+
+ # sort dt highest score first, sort gt ignore last
+ gtind = np.argsort([g['_ignore'] for g in gt], kind='mergesort')
+ gt = [gt[i] for i in gtind]
+ dtind = np.argsort([-d['score'] for d in dt], kind='mergesort')
+ dt = [dt[i] for i in dtind[0:maxDet]]
+ iscrowd = [int(o['iscrowd']) for o in gt]
+ # load computed ious
+ ious = self.ious[imgId, catId][:, gtind] if len(self.ious[imgId, catId]) > 0 else self.ious[imgId, catId]
+
+ T = len(p.iouThrs)
+ G = len(gt)
+ D = len(dt)
+ gtm = np.zeros((T,G))
+ dtm = np.zeros((T,D))
+ gtIg = np.array([g['_ignore'] for g in gt])
+ dtIg = np.zeros((T,D))
+        if len(ious) != 0:
+ for tind, t in enumerate(p.iouThrs):
+ for dind, d in enumerate(dt):
+ # information about best match so far (m=-1 -> unmatched)
+ iou = min([t,1-1e-10])
+ m = -1
+ for gind, g in enumerate(gt):
+ # if this gt already matched, and not a crowd, continue
+ if gtm[tind,gind]>0 and not iscrowd[gind]:
+ continue
+ # if dt matched to reg gt, and on ignore gt, stop
+ if m>-1 and gtIg[m]==0 and gtIg[gind]==1:
+ break
+ # continue to next gt unless better match made
+ if ious[dind,gind] < iou:
+ continue
+ # if match successful and best so far, store appropriately
+ iou=ious[dind,gind]
+ m=gind
+ # if match made store id of match for both dt and gt
+ if m ==-1:
+ continue
+ dtIg[tind,dind] = gtIg[m]
+ dtm[tind,dind] = gt[m]['id']
+ gtm[tind,m] = d['id']
+ # set unmatched detections outside of area range to ignore
+        a = np.array([d['area']<aRng[0] or d['area']>aRng[1] for d in dt]).reshape((1, len(dt)))
+ dtIg = np.logical_or(dtIg, np.logical_and(dtm==0, np.repeat(a,T,0)))
+ # store results for given image and category
+ return {
+ 'image_id': imgId,
+ 'category_id': catId,
+ 'aRng': aRng,
+ 'maxDet': maxDet,
+ 'dtIds': [d['id'] for d in dt],
+ 'gtIds': [g['id'] for g in gt],
+ 'dtMatches': dtm,
+ 'gtMatches': gtm,
+ 'dtScores': [d['score'] for d in dt],
+ 'gtIgnore': gtIg,
+ 'dtIgnore': dtIg,
+ }
+
+ def accumulate(self, p = None):
+ '''
+ Accumulate per image evaluation results and store the result in self.eval
+ :param p: input params for evaluation
+ :return: None
+ '''
+ print('Accumulating evaluation results...')
+ tic = time.time()
+ if not self.evalImgs:
+ print('Please run evaluate() first')
+        # allow input of customized parameters
+ if p is None:
+ p = self.params
+ p.catIds = p.catIds if p.useCats == 1 else [-1]
+ T = len(p.iouThrs)
+ R = len(p.recThrs)
+ K = len(p.catIds) if p.useCats else 1
+ A = len(p.areaRng)
+ M = len(p.maxDets)
+ precision = -np.ones((T,R,K,A,M)) # -1 for the precision of absent categories
+ recall = -np.ones((T,K,A,M))
+ scores = -np.ones((T,R,K,A,M))
+
+ # create dictionary for future indexing
+ _pe = self._paramsEval
+ catIds = _pe.catIds if _pe.useCats else [-1]
+ setK = set(catIds)
+ setA = set(map(tuple, _pe.areaRng))
+ setM = set(_pe.maxDets)
+ setI = set(_pe.imgIds)
+ # get inds to evaluate
+ k_list = [n for n, k in enumerate(p.catIds) if k in setK]
+ m_list = [m for n, m in enumerate(p.maxDets) if m in setM]
+ a_list = [n for n, a in enumerate(map(lambda x: tuple(x), p.areaRng)) if a in setA]
+ i_list = [n for n, i in enumerate(p.imgIds) if i in setI]
+ I0 = len(_pe.imgIds)
+ A0 = len(_pe.areaRng)
+ # retrieve E at each category, area range, and max number of detections
+ for k, k0 in enumerate(k_list):
+ Nk = k0*A0*I0
+ for a, a0 in enumerate(a_list):
+ Na = a0*I0
+ for m, maxDet in enumerate(m_list):
+ E = [self.evalImgs[Nk + Na + i] for i in i_list]
+                    E = [e for e in E if e is not None]
+ if len(E) == 0:
+ continue
+ dtScores = np.concatenate([e['dtScores'][0:maxDet] for e in E])
+
+                    # different sorting methods generate slightly different results;
+                    # mergesort is used to be consistent with the Matlab implementation.
+ inds = np.argsort(-dtScores, kind='mergesort')
+ dtScoresSorted = dtScores[inds]
+
+ dtm = np.concatenate([e['dtMatches'][:,0:maxDet] for e in E], axis=1)[:,inds]
+ dtIg = np.concatenate([e['dtIgnore'][:,0:maxDet] for e in E], axis=1)[:,inds]
+ gtIg = np.concatenate([e['gtIgnore'] for e in E])
+                    npig = np.count_nonzero(gtIg == 0)
+ if npig == 0:
+ continue
+ tps = np.logical_and( dtm, np.logical_not(dtIg) )
+ fps = np.logical_and(np.logical_not(dtm), np.logical_not(dtIg) )
+
+                    tp_sum = np.cumsum(tps, axis=1).astype(dtype=np.float64)
+                    fp_sum = np.cumsum(fps, axis=1).astype(dtype=np.float64)
+ for t, (tp, fp) in enumerate(zip(tp_sum, fp_sum)):
+ tp = np.array(tp)
+ fp = np.array(fp)
+ nd = len(tp)
+ rc = tp / npig
+ pr = tp / (fp+tp+np.spacing(1))
+ q = np.zeros((R,))
+ ss = np.zeros((R,))
+
+ if nd:
+ recall[t,k,a,m] = rc[-1]
+ else:
+ recall[t,k,a,m] = 0
+
+                    # numpy is slow without cython optimization for accessing elements;
+                    # using python lists instead gives a significant speed improvement
+ pr = pr.tolist(); q = q.tolist()
+
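+                    # walk the precision curve backwards to form its monotone
+                    # non-increasing envelope (best precision at recall >= r),
+                    # then sample that envelope at the fixed recall thresholds.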
+ for i in range(nd-1, 0, -1):
+ if pr[i] > pr[i-1]:
+ pr[i-1] = pr[i]
+
+ inds = np.searchsorted(rc, p.recThrs, side='left')
+ try:
+ for ri, pi in enumerate(inds):
+ q[ri] = pr[pi]
+ ss[ri] = dtScoresSorted[pi]
+                    except IndexError:  # recall never reaches the remaining thresholds
+ pass
+ precision[t,:,k,a,m] = np.array(q)
+ scores[t,:,k,a,m] = np.array(ss)
+ self.eval = {
+ 'params': p,
+ 'counts': [T, R, K, A, M],
+ 'date': datetime.datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
+ 'precision': precision,
+ 'recall': recall,
+ 'scores': scores,
+ }
+ toc = time.time()
+ print('DONE (t={:0.2f}s).'.format( toc-tic))
+
+ def summarize(self):
+ '''
+ Compute and display summary metrics for evaluation results.
+        Note this function can *only* be applied on the default parameter setting
+ '''
+ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
+ p = self.params
+ iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
+ titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
+ typeStr = '(AP)' if ap==1 else '(AR)'
+ iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
+ if iouThr is None else '{:0.2f}'.format(iouThr)
+
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
+ if ap == 1:
+ # dimension of precision: [TxRxKxAxM]
+ s = self.eval['precision']
+ # IoU
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+ s = s[:,:,:,aind,mind]
+ else:
+ # dimension of recall: [TxKxAxM]
+ s = self.eval['recall']
+ if iouThr is not None:
+ t = np.where(iouThr == p.iouThrs)[0]
+ s = s[t]
+ s = s[:,:,aind,mind]
+ if len(s[s>-1])==0:
+ mean_s = -1
+ else:
+ mean_s = np.mean(s[s>-1])
+ print(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s))
+ return mean_s
+ def _summarizeDets():
+ stats = np.zeros((12,))
+ stats[0] = _summarize(1)
+ stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
+ stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
+ stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
+ stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
+ stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
+ stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
+ stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
+ stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
+ stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
+ stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
+ stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
+ return stats
+ def _summarizeKps():
+ stats = np.zeros((10,))
+ stats[0] = _summarize(1, maxDets=20)
+ stats[1] = _summarize(1, maxDets=20, iouThr=.5)
+ stats[2] = _summarize(1, maxDets=20, iouThr=.75)
+ stats[3] = _summarize(1, maxDets=20, areaRng='medium')
+ stats[4] = _summarize(1, maxDets=20, areaRng='large')
+ stats[5] = _summarize(0, maxDets=20)
+ stats[6] = _summarize(0, maxDets=20, iouThr=.5)
+ stats[7] = _summarize(0, maxDets=20, iouThr=.75)
+ stats[8] = _summarize(0, maxDets=20, areaRng='medium')
+ stats[9] = _summarize(0, maxDets=20, areaRng='large')
+ return stats
+ if not self.eval:
+ raise Exception('Please run accumulate() first')
+ iouType = self.params.iouType
+ if iouType == 'segm' or iouType == 'bbox':
+ summarize = _summarizeDets
+ elif iouType == 'keypoints':
+ summarize = _summarizeKps
+ self.stats = summarize()
+
+ def __str__(self):
+ self.summarize()
+
+class Params:
+ '''
+ Params for coco evaluation api
+ '''
+ def setDetParams(self):
+ self.imgIds = []
+ self.catIds = []
+        # np.arange causes trouble: the data points it generates can be slightly larger than the true values
+ self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+ self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
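+        # -> iouThrs = [0.50, 0.55, ..., 0.95] (10 values);
+        #    recThrs = [0.00, 0.01, ..., 1.00] (101 values)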
+ self.maxDets = [1, 10, 100]
+ self.areaRng = [[0 ** 2, 1e5 ** 2], [0 ** 2, 32 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
+ self.areaRngLbl = ['all', 'small', 'medium', 'large']
+ self.useCats = 1
+
+ def setKpParams(self):
+ self.imgIds = []
+ self.catIds = []
+        # np.arange causes trouble: the data points it generates can be slightly larger than the true values
+ self.iouThrs = np.linspace(.5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+ self.recThrs = np.linspace(.0, 1.00, int(np.round((1.00 - .0) / .01)) + 1, endpoint=True)
+ self.maxDets = [20]
+ self.areaRng = [[0 ** 2, 1e5 ** 2], [32 ** 2, 96 ** 2], [96 ** 2, 1e5 ** 2]]
+ self.areaRngLbl = ['all', 'medium', 'large']
+ self.useCats = 1
+ self.kpt_oks_sigmas = np.array([.26, .25, .25, .35, .35, .79, .79, .72, .72, .62,.62, 1.07, 1.07, .87, .87, .89, .89])/10.0
+
+ def __init__(self, iouType='segm'):
+ if iouType == 'segm' or iouType == 'bbox':
+ self.setDetParams()
+ elif iouType == 'keypoints':
+ self.setKpParams()
+ else:
+ raise Exception('iouType not supported')
+ self.iouType = iouType
+ # useSegm is deprecated
+ self.useSegm = None
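+
+# A minimal usage sketch (assumptions: this file defines a COCOeval-style
+# evaluator class named `COCOeval` around the methods above, and
+# 'annotations.json'/'detections.json' are hypothetical file names):
+#
+#   from pycocotools.coco import COCO
+#   coco_gt = COCO('annotations.json')
+#   coco_dt = coco_gt.loadRes('detections.json')
+#   ev = COCOeval(coco_gt, coco_dt, iouType='bbox')
+#   ev.evaluate()    # per-image, per-category matching
+#   ev.accumulate()  # aggregate matches into precision/recall tables
+#   ev.summarize()   # print the 12 standard COCO detection metrics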
diff --git a/walt/datasets/custom.py b/walt/datasets/custom.py
new file mode 100644
index 0000000000000000000000000000000000000000..572742aa2e9c57cb6de2aac17939abf4a18216a3
--- /dev/null
+++ b/walt/datasets/custom.py
@@ -0,0 +1,324 @@
+import os.path as osp
+import warnings
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+from mmcv.utils import print_log
+from torch.utils.data import Dataset
+
+from mmdet.core import eval_map, eval_recalls
+from .builder import DATASETS
+from .pipelines import Compose
+
+
+@DATASETS.register_module()
+class CustomDatasetLocal(Dataset):
+ """Custom dataset for detection.
+
+ The annotation format is shown as follows. The `ann` field is optional for
+ testing.
+
+ .. code-block:: none
+
+ [
+ {
+ 'filename': 'a.jpg',
+ 'width': 1280,
+ 'height': 720,
+ 'ann': {
+                'bboxes': <np.ndarray> (n, 4) in (x1, y1, x2, y2) order.
+                'labels': <np.ndarray> (n, ),
+                'bboxes_ignore': <np.ndarray> (k, 4), (optional field)
+                'labels_ignore': <np.ndarray> (k, ) (optional field)
+ }
+ },
+ ...
+ ]
+
+ Args:
+ ann_file (str): Annotation file path.
+ pipeline (list[dict]): Processing pipeline.
+ classes (str | Sequence[str], optional): Specify classes to load.
+ If is None, ``cls.CLASSES`` will be used. Default: None.
+ data_root (str, optional): Data root for ``ann_file``,
+ ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
+ test_mode (bool, optional): If set True, annotation will not be loaded.
+ filter_empty_gt (bool, optional): If set true, images without bounding
+ boxes of the dataset's classes will be filtered out. This option
+ only works when `test_mode=False`, i.e., we never filter images
+ during tests.
+ """
+
+ CLASSES = None
+
+ def __init__(self,
+ ann_file,
+ pipeline,
+ classes=None,
+ data_root=None,
+ img_prefix='',
+ seg_prefix=None,
+ proposal_file=None,
+ test_mode=False,
+ filter_empty_gt=True):
+ self.ann_file = ann_file
+ self.data_root = data_root
+ self.img_prefix = img_prefix
+ self.seg_prefix = seg_prefix
+ self.proposal_file = proposal_file
+ self.test_mode = test_mode
+ self.filter_empty_gt = filter_empty_gt
+ self.CLASSES = self.get_classes(classes)
+
+ # join paths if data_root is specified
+ if self.data_root is not None:
+ if not osp.isabs(self.ann_file):
+ self.ann_file = osp.join(self.data_root, self.ann_file)
+ if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
+ self.img_prefix = osp.join(self.data_root, self.img_prefix)
+ if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
+ self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
+ if not (self.proposal_file is None
+ or osp.isabs(self.proposal_file)):
+ self.proposal_file = osp.join(self.data_root,
+ self.proposal_file)
+ # load annotations (and proposals)
+ self.data_infos = self.load_annotations(self.ann_file)
+
+ if self.proposal_file is not None:
+ self.proposals = self.load_proposals(self.proposal_file)
+ else:
+ self.proposals = None
+
+ # filter images too small and containing no annotations
+ if not test_mode:
+ valid_inds = self._filter_imgs()
+ self.data_infos = [self.data_infos[i] for i in valid_inds]
+ if self.proposals is not None:
+ self.proposals = [self.proposals[i] for i in valid_inds]
+ # set group flag for the sampler
+ self._set_group_flag()
+
+ # processing pipeline
+ self.pipeline = Compose(pipeline)
+
+ def __len__(self):
+ """Total number of samples of data."""
+ return len(self.data_infos)
+
+ def load_annotations(self, ann_file):
+ """Load annotation from annotation file."""
+ return mmcv.load(ann_file)
+
+ def load_proposals(self, proposal_file):
+ """Load proposal from proposal file."""
+ return mmcv.load(proposal_file)
+
+ def get_ann_info(self, idx):
+ """Get annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ return self.data_infos[idx]['ann']
+
+ def get_cat_ids(self, idx):
+ """Get category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+        return self.data_infos[idx]['ann']['labels'].astype(np.int64).tolist()
+
+ def pre_pipeline(self, results):
+ """Prepare results dict for pipeline."""
+ results['img_prefix'] = self.img_prefix
+ results['seg_prefix'] = self.seg_prefix
+ results['proposal_file'] = self.proposal_file
+ results['bbox_fields'] = []
+ results['bbox3d_fields'] = []
+ results['mask_fields'] = []
+ results['seg_fields'] = []
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small."""
+ if self.filter_empty_gt:
+ warnings.warn(
+ 'CustomDataset does not support filtering empty gt images.')
+ valid_inds = []
+ for i, img_info in enumerate(self.data_infos):
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ return valid_inds
+
+ def _set_group_flag(self):
+ """Set flag according to image aspect ratio.
+
+ Images with aspect ratio greater than 1 will be set as group 1,
+ otherwise group 0.
+ """
+ self.flag = np.zeros(len(self), dtype=np.uint8)
+ for i in range(len(self)):
+ img_info = self.data_infos[i]
+ if img_info['width'] / img_info['height'] > 1:
+ self.flag[i] = 1
+
+ def _rand_another(self, idx):
+ """Get another random index from the same group as the given index."""
+ pool = np.where(self.flag == self.flag[idx])[0]
+ return np.random.choice(pool)
+
+ def __getitem__(self, idx):
+ """Get training/test data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training/test data (with annotation if `test_mode` is set \
+ True).
+ """
+
+ if self.test_mode:
+ return self.prepare_test_img(idx)
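+        # if the pipeline drops a sample (returns None, e.g. no valid gt
+        # left after filtering/augmentation), re-draw an index from the same
+        # aspect-ratio group until a usable sample is produced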
+ while True:
+ data = self.prepare_train_img(idx)
+ if data is None:
+ idx = self._rand_another(idx)
+ continue
+ return data
+
+ def prepare_train_img(self, idx):
+ """Get training data and annotations after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Training data and annotation after pipeline with new keys \
+ introduced by pipeline.
+ """
+
+ img_info = self.data_infos[idx]
+ ann_info = self.get_ann_info(idx)
+ results = dict(img_info=img_info, ann_info=ann_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ def prepare_test_img(self, idx):
+ """Get testing data after pipeline.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Testing data after pipeline with new keys introduced by \
+ pipeline.
+ """
+
+ img_info = self.data_infos[idx]
+ results = dict(img_info=img_info)
+ if self.proposals is not None:
+ results['proposals'] = self.proposals[idx]
+ self.pre_pipeline(results)
+ return self.pipeline(results)
+
+ @classmethod
+ def get_classes(cls, classes=None):
+ """Get class names of current dataset.
+
+ Args:
+ classes (Sequence[str] | str | None): If classes is None, use
+ default CLASSES defined by builtin dataset. If classes is a
+ string, take it as a file name. The file contains the name of
+ classes where each line contains one class name. If classes is
+ a tuple or list, override the CLASSES defined by the dataset.
+
+ Returns:
+ tuple[str] or list[str]: Names of categories of the dataset.
+ """
+ if classes is None:
+ return cls.CLASSES
+
+ if isinstance(classes, str):
+ # take it as a file path
+ class_names = mmcv.list_from_file(classes)
+ elif isinstance(classes, (tuple, list)):
+ class_names = classes
+ else:
+ raise ValueError(f'Unsupported type {type(classes)} of classes.')
+
+ return class_names
+
+ def format_results(self, results, **kwargs):
+ """Place holder to format result to dataset specific output."""
+
+ def evaluate(self,
+ results,
+ metric='mAP',
+ logger=None,
+ proposal_nums=(100, 300, 1000),
+ iou_thr=0.5,
+ scale_ranges=None):
+ """Evaluate the dataset.
+
+ Args:
+ results (list): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated.
+ logger (logging.Logger | None | str): Logger used for printing
+ related information during evaluation. Default: None.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thr (float | list[float]): IoU threshold. Default: 0.5.
+ scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
+ Default: None.
+ """
+
+ if not isinstance(metric, str):
+ assert len(metric) == 1
+ metric = metric[0]
+ allowed_metrics = ['mAP', 'recall']
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ annotations = [self.get_ann_info(i) for i in range(len(self))]
+ eval_results = OrderedDict()
+ iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
+ if metric == 'mAP':
+ assert isinstance(iou_thrs, list)
+ mean_aps = []
+ for iou_thr in iou_thrs:
+ print_log(f'\n{"-" * 15}iou_thr: {iou_thr}{"-" * 15}')
+ mean_ap, _ = eval_map(
+ results,
+ annotations,
+ scale_ranges=scale_ranges,
+ iou_thr=iou_thr,
+ dataset=self.CLASSES,
+ logger=logger)
+ mean_aps.append(mean_ap)
+ eval_results[f'AP{int(iou_thr * 100):02d}'] = round(mean_ap, 3)
+ eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
+ elif metric == 'recall':
+ gt_bboxes = [ann['bboxes'] for ann in annotations]
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
+ for i, num in enumerate(proposal_nums):
+ for j, iou in enumerate(iou_thrs):
+ eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
+ if recalls.shape[1] > 1:
+ ar = recalls.mean(axis=1)
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ return eval_results
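+
+# A minimal usage sketch (assumptions: 'anns.pkl' is a hypothetical
+# mmcv-loadable list of dicts in the annotation format documented above,
+# and the class names are illustrative):
+#
+#   dataset = CustomDatasetLocal(
+#       ann_file='anns.pkl',
+#       pipeline=[dict(type='LoadImageFromFile')],
+#       classes=('person', 'vehicle'),
+#       img_prefix='data/imgs')
+#   sample = dataset[0]   # runs pre_pipeline + the configured pipeline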
diff --git a/walt/datasets/mask.py b/walt/datasets/mask.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb7b2bcd0f74f48f8eb0cb249334dc9095138976
--- /dev/null
+++ b/walt/datasets/mask.py
@@ -0,0 +1,110 @@
+__author__ = 'tsungyi'
+
+import pycocotools._mask as _mask
+
+# Interface for manipulating masks stored in RLE format.
+#
+# RLE is a simple yet efficient format for storing binary masks. RLE
+# first divides a vector (or vectorized image) into a series of piecewise
+# constant regions and then for each piece simply stores the length of
+# that piece. For example, given M=[0 0 1 1 1 0 1] the RLE counts would
+# be [2 3 1 1], or for M=[1 1 1 1 1 1 0] the counts would be [0 6 1]
+# (note that the odd counts are always the numbers of zeros). Instead of
+# storing the counts directly, additional compression is achieved with a
+# variable bitrate representation based on a common scheme called LEB128.
+#
+# Compression is greatest given large piecewise constant regions.
+# Specifically, the size of the RLE is proportional to the number of
+# *boundaries* in M (or for an image the number of boundaries in the y
+# direction). Assuming fairly simple shapes, the RLE representation is
+# O(sqrt(n)) where n is number of pixels in the object. Hence space usage
+# is substantially lower, especially for large simple objects (large n).
+#
+# Many common operations on masks can be computed directly using the RLE
+# (without need for decoding). This includes computations such as area,
+# union, intersection, etc. All of these operations are linear in the
+# size of the RLE, in other words they are O(sqrt(n)) where n is the area
+# of the object. Computing these operations on the original mask is O(n).
+# Thus, using the RLE can result in substantial computational savings.
+#
+# The following API functions are defined:
+# encode - Encode binary masks using RLE.
+# decode - Decode binary masks encoded via RLE.
+# merge - Compute union or intersection of encoded masks.
+# iou - Compute intersection over union between masks.
+# area - Compute area of encoded masks.
+# toBbox - Get bounding boxes surrounding encoded masks.
+# frPyObjects - Convert polygon, bbox, and uncompressed RLE to encoded
+# RLE mask.
+#
+# Usage:
+# Rs = encode( masks )
+# masks = decode( Rs )
+# R = merge( Rs, intersect=false )
+# o = iou( dt, gt, iscrowd )
+# a = area( Rs )
+# bbs = toBbox( Rs )
+# Rs = frPyObjects( [pyObjects], h, w )
+#
+# In the API the following formats are used:
+# Rs - [dict] Run-length encoding of binary masks
+# R - dict Run-length encoding of binary mask
+# masks - [hxwxn] Binary mask(s) (must have type np.ndarray(dtype=uint8)
+# in column-major order)
+# iscrowd - [nx1] list of np.ndarray. 1 indicates corresponding gt image has
+# crowd region to ignore
+# bbs - [nx4] Bounding box(es) stored as [x y w h]
+# poly - Polygon stored as [[x1 y1 x2 y2...],[x1 y1 ...],...] (2D list)
+# dt,gt - May be either bounding boxes or encoded masks
+# Both poly and bbs are 0-indexed (bbox=[0 0 1 1] encloses first pixel).
+#
+# Finally, a note about the intersection over union (iou) computation.
+# The standard iou of a ground truth (gt) and detected (dt) object is
+# iou(gt,dt) = area(intersect(gt,dt)) / area(union(gt,dt))
+# For "crowd" regions, we use a modified criteria. If a gt object is
+# marked as "iscrowd", we allow a dt to match any subregion of the gt.
+# Choosing gt' in the crowd gt that best matches the dt can be done using
+# gt'=intersect(dt,gt). Since by definition union(gt',dt)=dt, computing
+# iou(gt,dt,iscrowd) = iou(gt',dt) = area(intersect(gt,dt)) / area(dt)
+# For crowd gt regions we use this modified criteria above for the iou.
+#
+# To compile run "python setup.py build_ext --inplace"
+# Please do not contact us for help with compiling.
+#
+# Microsoft COCO Toolbox. version 2.0
+# Data, paper, and tutorials available at: http://mscoco.org/
+# Code written by Piotr Dollar and Tsung-Yi Lin, 2015.
+# Licensed under the Simplified BSD License [see coco/license.txt]
+
+iou = _mask.iou
+merge = _mask.merge
+frPyObjects = _mask.frPyObjects
+
+
+def encode(bimask):
+ if len(bimask.shape) == 3:
+ return _mask.encode(bimask)
+ elif len(bimask.shape) == 2:
+ h, w = bimask.shape
+ return _mask.encode(bimask.reshape((h, w, 1), order='F'))[0]
+
+
+def decode(rleObjs):
+ if type(rleObjs) == list:
+ return _mask.decode(rleObjs)
+ else:
+ return _mask.decode([rleObjs])[:, :, 0]
+
+
+def area(rleObjs):
+ if type(rleObjs) == list:
+ return _mask.area(rleObjs)
+ else:
+ return _mask.area([rleObjs])[0]
+
+
+def toBbox(rleObjs):
+ if type(rleObjs) == list:
+ return _mask.toBbox(rleObjs)
+ else:
+ return _mask.toBbox([rleObjs])[0]
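+
+# A minimal usage sketch of the wrappers above (assumption: masks must be
+# Fortran-ordered uint8 arrays of shape (h, w, n), per the format notes):
+#
+#   import numpy as np
+#   m = np.zeros((32, 32, 2), dtype=np.uint8, order='F')
+#   m[4:12, 4:12, 0] = 1
+#   m[8:16, 8:16, 1] = 1
+#   rles = encode(m)                           # list of RLE dicts
+#   areas = area(rles)                         # pixel counts, here [64, 64]
+#   boxes = toBbox(rles)                       # [x, y, w, h] per mask
+#   overlaps = iou([rles[0]], [rles[1]], [0])  # iscrowd=0 -> standard IoU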
diff --git a/walt/datasets/pipelines/__init__.py b/walt/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6f424debd1623e7511dd77da464a6639d816745
--- /dev/null
+++ b/walt/datasets/pipelines/__init__.py
@@ -0,0 +1,25 @@
+from .auto_augment import (AutoAugment, BrightnessTransform, ColorTransform,
+ ContrastTransform, EqualizeTransform, Rotate, Shear,
+ Translate)
+from .compose import Compose
+from .formating import (Collect, DefaultFormatBundle, ImageToTensor,
+ ToDataContainer, ToTensor, Transpose, to_tensor)
+from .instaboost import InstaBoost
+from .loading import (LoadAnnotations, LoadImageFromFile, LoadImageFromWebcam,
+ LoadMultiChannelImageFromFiles, LoadProposals)
+from .test_time_aug import MultiScaleFlipAug
+from .transforms import (Albu, CutOut, Expand, MinIoURandomCrop, Normalize,
+ Pad, PhotoMetricDistortion, RandomCenterCropPad,
+ RandomCrop, RandomFlip, Resize, SegRescale)
+
+__all__ = [
+ 'Compose', 'to_tensor', 'ToTensor', 'ImageToTensor', 'ToDataContainer',
+ 'Transpose', 'Collect', 'DefaultFormatBundle', 'LoadAnnotations',
+ 'LoadImageFromFile', 'LoadImageFromWebcam',
+ 'LoadMultiChannelImageFromFiles', 'LoadProposals', 'MultiScaleFlipAug',
+ 'Resize', 'RandomFlip', 'Pad', 'RandomCrop', 'Normalize', 'SegRescale',
+ 'MinIoURandomCrop', 'Expand', 'PhotoMetricDistortion', 'Albu',
+ 'InstaBoost', 'RandomCenterCropPad', 'AutoAugment', 'CutOut', 'Shear',
+ 'Rotate', 'ColorTransform', 'EqualizeTransform', 'BrightnessTransform',
+ 'ContrastTransform', 'Translate'
+]
diff --git a/walt/datasets/pipelines/auto_augment.py b/walt/datasets/pipelines/auto_augment.py
new file mode 100644
index 0000000000000000000000000000000000000000..e19adaec18a96cac4dbe1d8c2c9193e9901be1fb
--- /dev/null
+++ b/walt/datasets/pipelines/auto_augment.py
@@ -0,0 +1,890 @@
+import copy
+
+import cv2
+import mmcv
+import numpy as np
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+_MAX_LEVEL = 10
+
+
+def level_to_value(level, max_value):
+ """Map from level to values based on max_value."""
+ return (level / _MAX_LEVEL) * max_value
+
+
+def enhance_level_to_value(level, a=1.8, b=0.1):
+ """Map from level to values."""
+ return (level / _MAX_LEVEL) * a + b
+
+
+def random_negative(value, random_negative_prob):
+ """Randomly negate value based on random_negative_prob."""
+ return -value if np.random.rand() < random_negative_prob else value
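+
+# e.g. level_to_value(5, 0.3) -> 0.15; enhance_level_to_value(5) -> 1.0;
+# random_negative(0.15, 0.5) returns 0.15 or -0.15 with equal probability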
+
+
+def bbox2fields():
+ """The key correspondence from bboxes to labels, masks and
+ segmentations."""
+ bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+ bbox2seg = {
+ 'gt_bboxes': 'gt_semantic_seg',
+ }
+ return bbox2label, bbox2mask, bbox2seg
+
+
+@PIPELINES.register_module()
+class AutoAugment(object):
+ """Auto augmentation.
+
+ This data augmentation is proposed in `Learning Data Augmentation
+    Strategies for Object Detection <https://arxiv.org/abs/1906.11172>`_.
+
+ TODO: Implement 'Shear', 'Sharpness' and 'Rotate' transforms
+
+ Args:
+ policies (list[list[dict]]): The policies of auto augmentation. Each
+ policy in ``policies`` is a specific augmentation policy, and is
+ composed by several augmentations (dict). When AutoAugment is
+ called, a random policy in ``policies`` will be selected to
+ augment images.
+
+ Examples:
+ >>> replace = (104, 116, 124)
+ >>> policies = [
+ >>> [
+ >>> dict(type='Sharpness', prob=0.0, level=8),
+ >>> dict(
+ >>> type='Shear',
+ >>> prob=0.4,
+ >>> level=0,
+ >>> replace=replace,
+ >>> axis='x')
+ >>> ],
+ >>> [
+ >>> dict(
+ >>> type='Rotate',
+ >>> prob=0.6,
+ >>> level=10,
+ >>> replace=replace),
+ >>> dict(type='Color', prob=1.0, level=6)
+ >>> ]
+ >>> ]
+ >>> augmentation = AutoAugment(policies)
+        >>> img = np.ones((100, 100, 3))
+        >>> gt_bboxes = np.ones((10, 4))
+ >>> results = dict(img=img, gt_bboxes=gt_bboxes)
+ >>> results = augmentation(results)
+ """
+
+ def __init__(self, policies):
+ assert isinstance(policies, list) and len(policies) > 0, \
+ 'Policies must be a non-empty list.'
+ for policy in policies:
+ assert isinstance(policy, list) and len(policy) > 0, \
+ 'Each policy in policies must be a non-empty list.'
+ for augment in policy:
+ assert isinstance(augment, dict) and 'type' in augment, \
+ 'Each specific augmentation must be a dict with key' \
+ ' "type".'
+
+ self.policies = copy.deepcopy(policies)
+ self.transforms = [Compose(policy) for policy in self.policies]
+
+ def __call__(self, results):
+ transform = np.random.choice(self.transforms)
+ return transform(results)
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}(policies={self.policies})'
+
+
+@PIPELINES.register_module()
+class Shear(object):
+ """Apply Shear Transformation to image (and its corresponding bbox, mask,
+ segmentation).
+
+ Args:
+ level (int | float): The level should be in range [0,_MAX_LEVEL].
+ img_fill_val (int | float | tuple): The filled values for image border.
+            If float, the same fill value will be used for all three
+            channels of the image. If tuple, it should have 3 elements.
+ seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+ of the corresponding config. Default 255.
+ prob (float): The probability for performing Shear and should be in
+ range [0, 1].
+ direction (str): The direction for shear, either "horizontal"
+ or "vertical".
+ max_shear_magnitude (float): The maximum magnitude for Shear
+ transformation.
+ random_negative_prob (float): The probability that turns the
+ offset negative. Should be in range [0,1]
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+ """
+
+ def __init__(self,
+ level,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ prob=0.5,
+ direction='horizontal',
+ max_shear_magnitude=0.3,
+ random_negative_prob=0.5,
+ interpolation='bilinear'):
+ assert isinstance(level, (int, float)), 'The level must be type ' \
+ f'int or float, got {type(level)}.'
+ assert 0 <= level <= _MAX_LEVEL, 'The level should be in range ' \
+ f'[0,{_MAX_LEVEL}], got {level}.'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, 'img_fill_val as tuple must ' \
+ f'have 3 elements. got {len(img_fill_val)}.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError(
+ 'img_fill_val must be float or tuple with 3 elements.')
+ assert np.all([0 <= val <= 255 for val in img_fill_val]), 'all ' \
+            'elements of img_fill_val should be in range [0,255]. ' \
+ f'got {img_fill_val}.'
+ assert 0 <= prob <= 1.0, 'The probability of shear should be in ' \
+ f'range [0,1]. got {prob}.'
+        assert direction in ('horizontal', 'vertical'), 'direction must ' \
+            f'be either "horizontal" or "vertical". got {direction}.'
+ assert isinstance(max_shear_magnitude, float), 'max_shear_magnitude ' \
+ f'should be type float. got {type(max_shear_magnitude)}.'
+        assert 0. <= max_shear_magnitude <= 1., 'max_shear_magnitude ' \
+            'should be in range [0,1]. ' \
+ f'got {max_shear_magnitude}.'
+ self.level = level
+ self.magnitude = level_to_value(level, max_shear_magnitude)
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+ self.direction = direction
+ self.max_shear_magnitude = max_shear_magnitude
+ self.random_negative_prob = random_negative_prob
+ self.interpolation = interpolation
+
+ def _shear_img(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ interpolation='bilinear'):
+ """Shear the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ magnitude (int | float): The magnitude used for shear.
+ direction (str): The direction for shear, either "horizontal"
+ or "vertical".
+ interpolation (str): Same as in :func:`mmcv.imshear`.
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ img_sheared = mmcv.imshear(
+ img,
+ magnitude,
+ direction,
+ border_value=self.img_fill_val,
+ interpolation=interpolation)
+ results[key] = img_sheared.astype(img.dtype)
+
+ def _shear_bboxes(self, results, magnitude):
+ """Shear the bboxes."""
+ h, w, c = results['img_shape']
+ if self.direction == 'horizontal':
+ shear_matrix = np.stack([[1, magnitude],
+ [0, 1]]).astype(np.float32) # [2, 2]
+ else:
+ shear_matrix = np.stack([[1, 0], [magnitude,
+ 1]]).astype(np.float32)
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ coordinates = np.stack([[min_x, min_y], [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y]]) # [4, 2, nb_box, 1]
+ coordinates = coordinates[..., 0].transpose(
+ (2, 1, 0)).astype(np.float32) # [nb_box, 2, 4]
+ new_coords = np.matmul(shear_matrix[None, :, :],
+ coordinates) # [nb_box, 2, 4]
+ min_x = np.min(new_coords[:, 0, :], axis=-1)
+ min_y = np.min(new_coords[:, 1, :], axis=-1)
+ max_x = np.max(new_coords[:, 0, :], axis=-1)
+ max_y = np.max(new_coords[:, 1, :], axis=-1)
+ min_x = np.clip(min_x, a_min=0, a_max=w)
+ min_y = np.clip(min_y, a_min=0, a_max=h)
+ max_x = np.clip(max_x, a_min=min_x, a_max=w)
+ max_y = np.clip(max_y, a_min=min_y, a_max=h)
+ results[key] = np.stack([min_x, min_y, max_x, max_y],
+ axis=-1).astype(results[key].dtype)
+
+ def _shear_masks(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ fill_val=0,
+ interpolation='bilinear'):
+ """Shear the masks."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.shear((h, w),
+ magnitude,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation)
+
+ def _shear_seg(self,
+ results,
+ magnitude,
+ direction='horizontal',
+ fill_val=255,
+ interpolation='bilinear'):
+ """Shear the segmentation maps."""
+ for key in results.get('seg_fields', []):
+ seg = results[key]
+ results[key] = mmcv.imshear(
+ seg,
+ magnitude,
+ direction,
+ border_value=fill_val,
+ interpolation=interpolation).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_bbox_size=0):
+ """Filter bboxes and corresponding masks too small after shear
+ augmentation."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+
+ def __call__(self, results):
+ """Call function to shear images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Sheared results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ magnitude = random_negative(self.magnitude, self.random_negative_prob)
+ self._shear_img(results, magnitude, self.direction, self.interpolation)
+ self._shear_bboxes(results, magnitude)
+ # fill_val set to 0 for background of mask.
+ self._shear_masks(
+ results,
+ magnitude,
+ self.direction,
+ fill_val=0,
+ interpolation=self.interpolation)
+ self._shear_seg(
+ results,
+ magnitude,
+ self.direction,
+ fill_val=self.seg_ignore_label,
+ interpolation=self.interpolation)
+ self._filter_invalid(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'img_fill_val={self.img_fill_val}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'direction={self.direction}, '
+ repr_str += f'max_shear_magnitude={self.max_shear_magnitude}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob}, '
+ repr_str += f'interpolation={self.interpolation})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Rotate(object):
+ """Apply Rotate Transformation to image (and its corresponding bbox, mask,
+ segmentation).
+
+ Args:
+        level (int | float): The level should be in range [0,_MAX_LEVEL].
+ scale (int | float): Isotropic scale factor. Same in
+ ``mmcv.imrotate``.
+ center (int | float | tuple[float]): Center point (w, h) of the
+ rotation in the source image. If None, the center of the
+ image will be used. Same in ``mmcv.imrotate``.
+ img_fill_val (int | float | tuple): The fill value for image border.
+            If float, the same value will be used for all three
+            channels of the image. If tuple, it should have 3 elements
+            (e.g. equal to the number of channels of the image).
+ seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+ of the corresponding config. Default 255.
+        prob (float): The probability for performing the transformation and
+            should be in range [0, 1].
+ max_rotate_angle (int | float): The maximum angles for rotate
+ transformation.
+ random_negative_prob (float): The probability that turns the
+ offset negative.
+ """
+
+ def __init__(self,
+ level,
+ scale=1,
+ center=None,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ prob=0.5,
+ max_rotate_angle=30,
+ random_negative_prob=0.5):
+ assert isinstance(level, (int, float)), \
+ f'The level must be type int or float. got {type(level)}.'
+ assert 0 <= level <= _MAX_LEVEL, \
+            f'The level should be in range [0,{_MAX_LEVEL}]. got {level}.'
+ assert isinstance(scale, (int, float)), \
+ f'The scale must be type int or float. got type {type(scale)}.'
+ if isinstance(center, (int, float)):
+ center = (center, center)
+ elif isinstance(center, tuple):
+ assert len(center) == 2, 'center with type tuple must have '\
+ f'2 elements. got {len(center)} elements.'
+ else:
+ assert center is None, 'center must be None or type int, '\
+ f'float or tuple, got type {type(center)}.'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, 'img_fill_val as tuple must '\
+ f'have 3 elements. got {len(img_fill_val)}.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError(
+ 'img_fill_val must be float or tuple with 3 elements.')
+ assert np.all([0 <= val <= 255 for val in img_fill_val]), \
+            'all elements of img_fill_val should be in range [0,255]. '\
+ f'got {img_fill_val}.'
+        assert 0 <= prob <= 1.0, 'The probability should be in range [0,1]. '\
+            f'got {prob}.'
+ assert isinstance(max_rotate_angle, (int, float)), 'max_rotate_angle '\
+ f'should be type int or float. got type {type(max_rotate_angle)}.'
+ self.level = level
+ self.scale = scale
+ # Rotation angle in degrees. Positive values mean
+ # clockwise rotation.
+ self.angle = level_to_value(level, max_rotate_angle)
+ self.center = center
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+ self.max_rotate_angle = max_rotate_angle
+ self.random_negative_prob = random_negative_prob
+
+ def _rotate_img(self, results, angle, center=None, scale=1.0):
+ """Rotate the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ angle (float): Rotation angle in degrees, positive values
+ mean clockwise rotation. Same in ``mmcv.imrotate``.
+ center (tuple[float], optional): Center point (w, h) of the
+ rotation. Same in ``mmcv.imrotate``.
+ scale (int | float): Isotropic scale factor. Same in
+ ``mmcv.imrotate``.
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key].copy()
+ img_rotated = mmcv.imrotate(
+ img, angle, center, scale, border_value=self.img_fill_val)
+ results[key] = img_rotated.astype(img.dtype)
+
+ def _rotate_bboxes(self, results, rotate_matrix):
+ """Rotate the bboxes."""
+ h, w, c = results['img_shape']
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ coordinates = np.stack([[min_x, min_y], [max_x, min_y],
+ [min_x, max_y],
+ [max_x, max_y]]) # [4, 2, nb_bbox, 1]
+ # pad 1 to convert from format [x, y] to homogeneous
+ # coordinates format [x, y, 1]
+ coordinates = np.concatenate(
+ (coordinates,
+ np.ones((4, 1, coordinates.shape[2], 1), coordinates.dtype)),
+ axis=1) # [4, 3, nb_bbox, 1]
+ coordinates = coordinates.transpose(
+ (2, 0, 1, 3)) # [nb_bbox, 4, 3, 1]
+ rotated_coords = np.matmul(rotate_matrix,
+ coordinates) # [nb_bbox, 4, 2, 1]
+ rotated_coords = rotated_coords[..., 0] # [nb_bbox, 4, 2]
+ min_x, min_y = np.min(
+ rotated_coords[:, :, 0], axis=1), np.min(
+ rotated_coords[:, :, 1], axis=1)
+ max_x, max_y = np.max(
+ rotated_coords[:, :, 0], axis=1), np.max(
+ rotated_coords[:, :, 1], axis=1)
+ min_x, min_y = np.clip(
+ min_x, a_min=0, a_max=w), np.clip(
+ min_y, a_min=0, a_max=h)
+ max_x, max_y = np.clip(
+ max_x, a_min=min_x, a_max=w), np.clip(
+ max_y, a_min=min_y, a_max=h)
+ results[key] = np.stack([min_x, min_y, max_x, max_y],
+ axis=-1).astype(results[key].dtype)
+
+ def _rotate_masks(self,
+ results,
+ angle,
+ center=None,
+ scale=1.0,
+ fill_val=0):
+ """Rotate the masks."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.rotate((h, w), angle, center, scale, fill_val)
+
+ def _rotate_seg(self,
+ results,
+ angle,
+ center=None,
+ scale=1.0,
+ fill_val=255):
+ """Rotate the segmentation map."""
+ for key in results.get('seg_fields', []):
+ seg = results[key].copy()
+ results[key] = mmcv.imrotate(
+ seg, angle, center, scale,
+ border_value=fill_val).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_bbox_size=0):
+ """Filter bboxes and corresponding masks too small after rotate
+ augmentation."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_bbox_size) & (bbox_h > min_bbox_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+
+ def __call__(self, results):
+ """Call function to rotate images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Rotated results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ h, w = results['img'].shape[:2]
+ center = self.center
+ if center is None:
+ center = ((w - 1) * 0.5, (h - 1) * 0.5)
+ angle = random_negative(self.angle, self.random_negative_prob)
+ self._rotate_img(results, angle, center, self.scale)
+ rotate_matrix = cv2.getRotationMatrix2D(center, -angle, self.scale)
+ self._rotate_bboxes(results, rotate_matrix)
+ self._rotate_masks(results, angle, center, self.scale, fill_val=0)
+ self._rotate_seg(
+ results, angle, center, self.scale, fill_val=self.seg_ignore_label)
+ self._filter_invalid(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'scale={self.scale}, '
+ repr_str += f'center={self.center}, '
+ repr_str += f'img_fill_val={self.img_fill_val}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label}, '
+ repr_str += f'prob={self.prob}, '
+ repr_str += f'max_rotate_angle={self.max_rotate_angle}, '
+ repr_str += f'random_negative_prob={self.random_negative_prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Translate(object):
+ """Translate the images, bboxes, masks and segmentation maps horizontally
+ or vertically.
+
+ Args:
+ level (int | float): The level for Translate and should be in
+ range [0,_MAX_LEVEL].
+ prob (float): The probability for performing translation and
+ should be in range [0, 1].
+ img_fill_val (int | float | tuple): The filled value for image
+ border. If float, the same fill value will be used for all
+            three channels of the image. If tuple, it should have 3
+            elements (e.g. equal to the number of channels of the image).
+ seg_ignore_label (int): The fill value used for segmentation map.
+            Note this value must equal ``ignore_label`` in ``semantic_head``
+ of the corresponding config. Default 255.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+        max_translate_offset (int | float): The maximum pixel offset for
+ Translate.
+ random_negative_prob (float): The probability that turns the
+ offset negative.
+        min_size (int | float): The minimum bbox size (in pixels) for
+            filtering out invalid bboxes after the translation.
+ """
+
+ def __init__(self,
+ level,
+ prob=0.5,
+ img_fill_val=128,
+ seg_ignore_label=255,
+ direction='horizontal',
+ max_translate_offset=250.,
+ random_negative_prob=0.5,
+ min_size=0):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level used for calculating Translate\'s offset should be ' \
+ 'in range [0,_MAX_LEVEL]'
+ assert 0 <= prob <= 1.0, \
+ 'The probability of translation should be in range [0, 1].'
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, \
+ 'img_fill_val as tuple must have 3 elements.'
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError('img_fill_val must be type float or tuple.')
+ assert np.all([0 <= val <= 255 for val in img_fill_val]), \
+            'all elements of img_fill_val should be in range [0,255].'
+ assert direction in ('horizontal', 'vertical'), \
+ 'direction should be "horizontal" or "vertical".'
+ assert isinstance(max_translate_offset, (int, float)), \
+ 'The max_translate_offset must be type int or float.'
+ # the offset used for translation
+ self.offset = int(level_to_value(level, max_translate_offset))
+ self.level = level
+ self.prob = prob
+ self.img_fill_val = img_fill_val
+ self.seg_ignore_label = seg_ignore_label
+ self.direction = direction
+ self.max_translate_offset = max_translate_offset
+ self.random_negative_prob = random_negative_prob
+ self.min_size = min_size
+
+ def _translate_img(self, results, offset, direction='horizontal'):
+ """Translate the image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ offset (int | float): The offset for translate.
+ direction (str): The translate direction, either "horizontal"
+ or "vertical".
+ """
+ for key in results.get('img_fields', ['img']):
+ img = results[key].copy()
+ results[key] = mmcv.imtranslate(
+ img, offset, direction, self.img_fill_val).astype(img.dtype)
+
+ def _translate_bboxes(self, results, offset):
+ """Shift bboxes horizontally or vertically, according to offset."""
+ h, w, c = results['img_shape']
+ for key in results.get('bbox_fields', []):
+ min_x, min_y, max_x, max_y = np.split(
+ results[key], results[key].shape[-1], axis=-1)
+ if self.direction == 'horizontal':
+ min_x = np.maximum(0, min_x + offset)
+ max_x = np.minimum(w, max_x + offset)
+ elif self.direction == 'vertical':
+ min_y = np.maximum(0, min_y + offset)
+ max_y = np.minimum(h, max_y + offset)
+
+ # the boxes translated outside of image will be filtered along with
+ # the corresponding masks, by invoking ``_filter_invalid``.
+ results[key] = np.concatenate([min_x, min_y, max_x, max_y],
+ axis=-1)
+
+ def _translate_masks(self,
+ results,
+ offset,
+ direction='horizontal',
+ fill_val=0):
+ """Translate masks horizontally or vertically."""
+ h, w, c = results['img_shape']
+ for key in results.get('mask_fields', []):
+ masks = results[key]
+ results[key] = masks.translate((h, w), offset, direction, fill_val)
+
+ def _translate_seg(self,
+ results,
+ offset,
+ direction='horizontal',
+ fill_val=255):
+ """Translate segmentation maps horizontally or vertically."""
+ for key in results.get('seg_fields', []):
+ seg = results[key].copy()
+ results[key] = mmcv.imtranslate(seg, offset, direction,
+ fill_val).astype(seg.dtype)
+
+ def _filter_invalid(self, results, min_size=0):
+ """Filter bboxes and masks too small or translated out of image."""
+ bbox2label, bbox2mask, _ = bbox2fields()
+ for key in results.get('bbox_fields', []):
+ bbox_w = results[key][:, 2] - results[key][:, 0]
+ bbox_h = results[key][:, 3] - results[key][:, 1]
+ valid_inds = (bbox_w > min_size) & (bbox_h > min_size)
+ valid_inds = np.nonzero(valid_inds)[0]
+ results[key] = results[key][valid_inds]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][valid_inds]
+ return results
+
+ def __call__(self, results):
+ """Call function to translate images, bounding boxes, masks and
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Translated results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ offset = random_negative(self.offset, self.random_negative_prob)
+ self._translate_img(results, offset, self.direction)
+ self._translate_bboxes(results, offset)
+        # fill_val defaults to 0 for BitmapMasks and None for PolygonMasks.
+ self._translate_masks(results, offset, self.direction)
+ # fill_val set to ``seg_ignore_label`` for the ignored value
+ # of segmentation map.
+ self._translate_seg(
+ results, offset, self.direction, fill_val=self.seg_ignore_label)
+ self._filter_invalid(results, min_size=self.min_size)
+ return results
+
+
+@PIPELINES.register_module()
+class ColorTransform(object):
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_color_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get('img_fields', ['img']):
+            # NOTE: by default the image is assumed to be in BGR format
+ img = results[key]
+ results[key] = mmcv.adjust_color(img, factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_color_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class EqualizeTransform(object):
+ """Apply Equalize transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ prob (float): The probability for performing Equalize transformation.
+ """
+
+ def __init__(self, prob=0.5):
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.prob = prob
+
+ def _imequalize(self, results):
+ """Equalizes the histogram of one image."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.imequalize(img).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Equalize transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._imequalize(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(prob={self.prob})'
+
+
+@PIPELINES.register_module()
+class BrightnessTransform(object):
+ """Apply Brightness transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Brightness transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_brightness_img(self, results, factor=1.0):
+ """Adjust the brightness of image."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_brightness(img,
+ factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Brightness transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_brightness_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class ContrastTransform(object):
+ """Apply Contrast transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Contrast transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ assert isinstance(level, (int, float)), \
+ 'The level must be type int or float.'
+ assert 0 <= level <= _MAX_LEVEL, \
+ 'The level should be in range [0,_MAX_LEVEL].'
+ assert 0 <= prob <= 1.0, \
+ 'The probability should be in range [0,1].'
+ self.level = level
+ self.prob = prob
+ self.factor = enhance_level_to_value(level)
+
+ def _adjust_contrast_img(self, results, factor=1.0):
+ """Adjust the image contrast."""
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ results[key] = mmcv.adjust_contrast(img, factor).astype(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Contrast transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.rand() > self.prob:
+ return results
+ self._adjust_contrast_img(results, self.factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(level={self.level}, '
+ repr_str += f'prob={self.prob})'
+ return repr_str
diff --git a/walt/datasets/pipelines/compose.py b/walt/datasets/pipelines/compose.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7a269832bd13983197daf1001b397a9c416c762
--- /dev/null
+++ b/walt/datasets/pipelines/compose.py
@@ -0,0 +1,52 @@
+import collections
+
+from mmcv.utils import build_from_cfg
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class Compose(object):
+ """Compose multiple transforms sequentially.
+
+ Args:
+ transforms (Sequence[dict | callable]): Sequence of transform object or
+ config dict to be composed.
+ """
+
+ def __init__(self, transforms):
+ assert isinstance(transforms, collections.abc.Sequence)
+ self.transforms = []
+ for transform in transforms:
+ if isinstance(transform, dict):
+ transform = build_from_cfg(transform, PIPELINES)
+ self.transforms.append(transform)
+ elif callable(transform):
+ self.transforms.append(transform)
+ else:
+ raise TypeError('transform must be callable or a dict')
+
+ def __call__(self, data):
+ """Call function to apply transforms sequentially.
+
+ Args:
+ data (dict): A result dict contains the data to transform.
+
+ Returns:
+ dict: Transformed data.
+ """
+
+ for t in self.transforms:
+ data = t(data)
+ if data is None:
+ return None
+ return data
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + '('
+ for t in self.transforms:
+ format_string += '\n'
+ format_string += f' {t}'
+ format_string += '\n)'
+ return format_string
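+
+# A minimal usage sketch (assumption: 'LoadImageFromFile' and 'Resize' are
+# registered in PIPELINES by this package's loading/transforms modules):
+#
+#   pipeline = Compose([
+#       dict(type='LoadImageFromFile'),
+#       dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+#   ])
+#   results = dict(img_prefix='data/imgs', img_info=dict(filename='a.jpg'))
+#   results = pipeline(results)  # each transform updates the dict in turn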
diff --git a/walt/datasets/pipelines/formating.py b/walt/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5ee540cd37f070fa47231cc569e97850655ad1a
--- /dev/null
+++ b/walt/datasets/pipelines/formating.py
@@ -0,0 +1,366 @@
+from collections.abc import Sequence
+
+import mmcv
+import numpy as np
+import torch
+from mmcv.parallel import DataContainer as DC
+
+from ..builder import PIPELINES
+
+
+def to_tensor(data):
+ """Convert objects of various python types to :obj:`torch.Tensor`.
+
+ Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
+ :class:`Sequence`, :class:`int` and :class:`float`.
+
+ Args:
+ data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
+ be converted.
+ """
+
+ if isinstance(data, torch.Tensor):
+ return data
+ elif isinstance(data, np.ndarray):
+ return torch.from_numpy(data)
+ elif isinstance(data, Sequence) and not mmcv.is_str(data):
+ return torch.tensor(data)
+ elif isinstance(data, int):
+ return torch.LongTensor([data])
+ elif isinstance(data, float):
+ return torch.FloatTensor([data])
+ else:
+ raise TypeError(f'type {type(data)} cannot be converted to tensor.')
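+
+
+# Editor's sketch (not part of the original file): a quick illustration of the
+# dispatch above -- ndarrays share memory with the returned tensor, while
+# scalars become 1-element tensors.
+def _example_to_tensor():
+    assert to_tensor(np.zeros((2, 3))).shape == (2, 3)
+    assert to_tensor(3).dtype == torch.int64      # int -> LongTensor([3])
+    assert to_tensor(0.5).dtype == torch.float32  # float -> FloatTensor([0.5])
+    assert to_tensor([1, 2, 3]).shape == (3,)     # sequence -> torch.tensor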
+
+
+@PIPELINES.register_module()
+class ToTensor(object):
+ """Convert some results to :obj:`torch.Tensor` by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys that need to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert data in results to :obj:`torch.Tensor`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted
+ to :obj:`torch.Tensor`.
+ """
+ for key in self.keys:
+ results[key] = to_tensor(results[key])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class ImageToTensor(object):
+ """Convert image to :obj:`torch.Tensor` by given keys.
+
+ The dimension order of input image is (H, W, C). The pipeline will convert
+    it to (C, H, W). If only 2 dimensions (H, W) are given, the output is
+    (1, H, W).
+
+ Args:
+ keys (Sequence[str]): Key of images to be converted to Tensor.
+ """
+
+ def __init__(self, keys):
+ self.keys = keys
+
+ def __call__(self, results):
+ """Call function to convert image in results to :obj:`torch.Tensor` and
+ transpose the channel order.
+
+ Args:
+ results (dict): Result dict contains the image data to convert.
+
+ Returns:
+ dict: The result dict contains the image converted
+ to :obj:`torch.Tensor` and transposed to (C, H, W) order.
+ """
+ for key in self.keys:
+ img = results[key]
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ results[key] = to_tensor(img.transpose(2, 0, 1))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(keys={self.keys})'
+
+
+@PIPELINES.register_module()
+class Transpose(object):
+ """Transpose some results by given keys.
+
+ Args:
+ keys (Sequence[str]): Keys of results to be transposed.
+ order (Sequence[int]): Order of transpose.
+ """
+
+ def __init__(self, keys, order):
+ self.keys = keys
+ self.order = order
+
+ def __call__(self, results):
+ """Call function to transpose the channel order of data in results.
+
+ Args:
+ results (dict): Result dict contains the data to transpose.
+
+ Returns:
+ dict: The result dict contains the data transposed to \
+ ``self.order``.
+ """
+ for key in self.keys:
+ results[key] = results[key].transpose(self.order)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, order={self.order})'
+
+
+@PIPELINES.register_module()
+class ToDataContainer(object):
+ """Convert results to :obj:`mmcv.DataContainer` by given fields.
+
+ Args:
+ fields (Sequence[dict]): Each field is a dict like
+ ``dict(key='xxx', **kwargs)``. The ``key`` in result will
+ be converted to :obj:`mmcv.DataContainer` with ``**kwargs``.
+ Default: ``(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))``.
+ """
+
+ def __init__(self,
+ fields=(dict(key='img', stack=True), dict(key='gt_bboxes'),
+ dict(key='gt_labels'))):
+ self.fields = fields
+
+ def __call__(self, results):
+ """Call function to convert data in results to
+ :obj:`mmcv.DataContainer`.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data converted to \
+ :obj:`mmcv.DataContainer`.
+ """
+
+ for field in self.fields:
+ field = field.copy()
+ key = field.pop('key')
+ results[key] = DC(results[key], **field)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(fields={self.fields})'
+
+
+@PIPELINES.register_module()
+class DefaultFormatBundle(object):
+ """Default formatting bundle.
+
+ It simplifies the pipeline of formatting common fields, including "img",
+ "proposals", "gt_bboxes", "gt_labels", "gt_masks" and "gt_semantic_seg".
+ These fields are formatted as follows.
+
+ - img: (1)transpose, (2)to tensor, (3)to DataContainer (stack=True)
+ - proposals: (1)to tensor, (2)to DataContainer
+ - gt_bboxes: (1)to tensor, (2)to DataContainer
+ - gt_bboxes_ignore: (1)to tensor, (2)to DataContainer
+ - gt_labels: (1)to tensor, (2)to DataContainer
+ - gt_masks: (1)to tensor, (2)to DataContainer (cpu_only=True)
+ - gt_semantic_seg: (1)unsqueeze dim-0 (2)to tensor, \
+ (3)to DataContainer (stack=True)
+ """
+
+ def __call__(self, results):
+ """Call function to transform and format common fields in results.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ dict: The result dict contains the data that is formatted with \
+ default bundle.
+ """
+
+ if 'img' in results:
+ img = results['img']
+ # add default meta keys
+ results = self._add_default_meta_keys(results)
+ if len(img.shape) < 3:
+ img = np.expand_dims(img, -1)
+ img = np.ascontiguousarray(img.transpose(2, 0, 1))
+ results['img'] = DC(to_tensor(img), stack=True)
+ for key in ['proposals', 'gt_bboxes', 'gt_bboxes_ignore', 'gt_labels','gt_bboxes_3d', 'gt_bboxes_3d_proj']:
+ if key not in results:
+ continue
+ results[key] = DC(to_tensor(results[key]))
+ if 'gt_bboxes_3d' in results:
+ results['gt_bboxes_3d'] = DC(results['gt_bboxes_3d'], cpu_only=True)
+ if 'gt_masks' in results:
+ results['gt_masks'] = DC(results['gt_masks'], cpu_only=True)
+ if 'gt_semantic_seg' in results:
+ results['gt_semantic_seg'] = DC(
+ to_tensor(results['gt_semantic_seg'][None, ...]), stack=True)
+ return results
+
+ def _add_default_meta_keys(self, results):
+ """Add default meta keys.
+
+ We set default meta keys including `pad_shape`, `scale_factor` and
+ `img_norm_cfg` to avoid the case where no `Resize`, `Normalize` and
+ `Pad` are implemented during the whole pipeline.
+
+ Args:
+ results (dict): Result dict contains the data to convert.
+
+ Returns:
+ results (dict): Updated result dict contains the data to convert.
+ """
+ img = results['img']
+ results.setdefault('pad_shape', img.shape)
+ results.setdefault('scale_factor', 1.0)
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results.setdefault(
+ 'img_norm_cfg',
+ dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False))
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__
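+
+
+# Editor's sketch (not part of the original file): after the bundle, 'img'
+# holds a stacked DataContainer with a (C, H, W) tensor, and the default meta
+# keys exist even when no Resize/Normalize/Pad ran earlier in the pipeline.
+def _example_default_format_bundle():
+    bundle = DefaultFormatBundle()
+    results = bundle(dict(img=np.zeros((4, 4, 3), dtype=np.uint8)))
+    assert results['img'].data.shape == (3, 4, 4)
+    assert results['scale_factor'] == 1.0  # added by _add_default_meta_keys
+    return results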
+
+
+@PIPELINES.register_module()
+class Collect(object):
+ """Collect data from the loader relevant to the specific task.
+
+ This is usually the last stage of the data loader pipeline. Typically keys
+ is set to some subset of "img", "proposals", "gt_bboxes",
+ "gt_bboxes_ignore", "gt_labels", and/or "gt_masks".
+
+ The "img_meta" item is always populated. The contents of the "img_meta"
+ dictionary depends on "meta_keys". By default this includes:
+
+ - "img_shape": shape of the image input to the network as a tuple \
+ (h, w, c). Note that images may be zero padded on the \
+ bottom/right if the batch tensor is larger than this shape.
+
+ - "scale_factor": a float indicating the preprocessing scale
+
+ - "flip": a boolean indicating if image flip transform was used
+
+ - "filename": path to the image file
+
+ - "ori_shape": original shape of the image as a tuple (h, w, c)
+
+ - "pad_shape": image shape after padding
+
+ - "img_norm_cfg": a dict of normalization information:
+
+ - mean - per channel mean subtraction
+ - std - per channel std divisor
+ - to_rgb - bool indicating if bgr was converted to rgb
+
+ Args:
+ keys (Sequence[str]): Keys of results to be collected in ``data``.
+ meta_keys (Sequence[str], optional): Meta keys to be converted to
+ ``mmcv.DataContainer`` and collected in ``data[img_metas]``.
+ Default: ``('filename', 'ori_filename', 'ori_shape', 'img_shape',
+ 'pad_shape', 'scale_factor', 'flip', 'flip_direction',
+ 'img_norm_cfg')``
+ """
+
+ def __init__(self,
+ keys,
+ meta_keys=('filename', 'ori_filename', 'ori_shape',
+ 'img_shape', 'pad_shape', 'scale_factor', 'flip',
+ 'flip_direction', 'img_norm_cfg')):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
+ """Call function to collect keys in results. The keys in ``meta_keys``
+ will be converted to :obj:mmcv.DataContainer.
+
+ Args:
+ results (dict): Result dict contains the data to collect.
+
+ Returns:
+ dict: The result dict contains the following keys
+
+            - keys in ``self.keys``
+ - ``img_metas``
+ """
+
+ data = {}
+ img_meta = {}
+ for key in self.meta_keys:
+ img_meta[key] = results[key]
+ data['img_metas'] = DC(img_meta, cpu_only=True)
+ for key in self.keys:
+ data[key] = results[key]
+ return data
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(keys={self.keys}, meta_keys={self.meta_keys})'
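+
+
+# Editor's sketch (not part of the original file): Collect keeps only ``keys``
+# plus an 'img_metas' DataContainer, so every meta key must already be present
+# in ``results`` by this stage.
+def _example_collect():
+    collect = Collect(keys=['img'], meta_keys=('ori_shape', 'img_shape'))
+    results = dict(img=torch.zeros(3, 4, 4),
+                   ori_shape=(4, 4, 3), img_shape=(4, 4, 3))
+    data = collect(results)
+    assert set(data) == {'img', 'img_metas'}
+    return data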
+
+
+@PIPELINES.register_module()
+class WrapFieldsToLists(object):
+ """Wrap fields of the data dictionary into lists for evaluation.
+
+ This class can be used as a last step of a test or validation
+ pipeline for single image evaluation or inference.
+
+ Example:
+ >>> test_pipeline = [
+ >>> dict(type='LoadImageFromFile'),
+        >>> dict(type='Normalize',
+        >>>      mean=[123.675, 116.28, 103.53],
+        >>>      std=[58.395, 57.12, 57.375],
+        >>>      to_rgb=True),
+ >>> dict(type='Pad', size_divisor=32),
+ >>> dict(type='ImageToTensor', keys=['img']),
+ >>> dict(type='Collect', keys=['img']),
+ >>> dict(type='WrapFieldsToLists')
+ >>> ]
+ """
+
+ def __call__(self, results):
+ """Call function to wrap fields into lists.
+
+ Args:
+ results (dict): Result dict contains the data to wrap.
+
+ Returns:
+ dict: The result dict where value of ``self.keys`` are wrapped \
+ into list.
+ """
+
+ # Wrap dict fields into lists
+ for key, val in results.items():
+ results[key] = [val]
+ return results
+
+ def __repr__(self):
+ return f'{self.__class__.__name__}()'
diff --git a/walt/datasets/pipelines/instaboost.py b/walt/datasets/pipelines/instaboost.py
new file mode 100644
index 0000000000000000000000000000000000000000..38b6819f60587a6e0c0f6d57bfda32bb3a7a4267
--- /dev/null
+++ b/walt/datasets/pipelines/instaboost.py
@@ -0,0 +1,98 @@
+import numpy as np
+
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class InstaBoost(object):
+ r"""Data augmentation method in `InstaBoost: Boosting Instance
+ Segmentation Via Probability Map Guided Copy-Pasting
+ `_.
+
+ Refer to https://github.com/GothicAi/Instaboost for implementation details.
+ """
+
+ def __init__(self,
+ action_candidate=('normal', 'horizontal', 'skip'),
+ action_prob=(1, 0, 0),
+ scale=(0.8, 1.2),
+ dx=15,
+ dy=15,
+ theta=(-1, 1),
+ color_prob=0.5,
+ hflag=False,
+ aug_ratio=0.5):
+ try:
+ import instaboostfast as instaboost
+ except ImportError:
+ raise ImportError(
+ 'Please run "pip install instaboostfast" '
+ 'to install instaboostfast first for instaboost augmentation.')
+ self.cfg = instaboost.InstaBoostConfig(action_candidate, action_prob,
+ scale, dx, dy, theta,
+ color_prob, hflag)
+ self.aug_ratio = aug_ratio
+
+ def _load_anns(self, results):
+ labels = results['ann_info']['labels']
+ masks = results['ann_info']['masks']
+ bboxes = results['ann_info']['bboxes']
+ n = len(labels)
+
+ anns = []
+ for i in range(n):
+ label = labels[i]
+ bbox = bboxes[i]
+ mask = masks[i]
+ x1, y1, x2, y2 = bbox
+ # assert (x2 - x1) >= 1 and (y2 - y1) >= 1
+ bbox = [x1, y1, x2 - x1, y2 - y1]
+ anns.append({
+ 'category_id': label,
+ 'segmentation': mask,
+ 'bbox': bbox
+ })
+
+ return anns
+
+ def _parse_anns(self, results, anns, img):
+ gt_bboxes = []
+ gt_labels = []
+ gt_masks_ann = []
+ for ann in anns:
+ x1, y1, w, h = ann['bbox']
+ # TODO: more essential bug need to be fixed in instaboost
+ if w <= 0 or h <= 0:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ gt_bboxes.append(bbox)
+ gt_labels.append(ann['category_id'])
+ gt_masks_ann.append(ann['segmentation'])
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ results['ann_info']['labels'] = gt_labels
+ results['ann_info']['bboxes'] = gt_bboxes
+ results['ann_info']['masks'] = gt_masks_ann
+ results['img'] = img
+ return results
+
+ def __call__(self, results):
+ img = results['img']
+ orig_type = img.dtype
+ anns = self._load_anns(results)
+ if np.random.choice([0, 1], p=[1 - self.aug_ratio, self.aug_ratio]):
+ try:
+ import instaboostfast as instaboost
+ except ImportError:
+ raise ImportError('Please run "pip install instaboostfast" '
+ 'to install instaboostfast first.')
+ anns, img = instaboost.get_new_data(
+ anns, img.astype(np.uint8), self.cfg, background=None)
+
+ results = self._parse_anns(results, anns, img.astype(orig_type))
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(cfg={self.cfg}, aug_ratio={self.aug_ratio})'
+ return repr_str
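+
+
+# Editor's sketch (assumption, not part of the original file): InstaBoost
+# rewrites results['ann_info'] and results['img'] in place, so in a training
+# pipeline it belongs after image loading but before LoadAnnotations converts
+# ann_info into gt_* fields.
+def _example_instaboost_pipeline():
+    return [
+        dict(type='LoadImageFromFile'),
+        dict(type='InstaBoost', aug_ratio=0.5),
+        dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+    ]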
diff --git a/walt/datasets/pipelines/loading.py b/walt/datasets/pipelines/loading.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0369aadc3c4b76ab87db608fc9e31e0040f583f
--- /dev/null
+++ b/walt/datasets/pipelines/loading.py
@@ -0,0 +1,465 @@
+import os.path as osp
+
+import mmcv
+import numpy as np
+import pycocotools.mask as maskUtils
+
+from mmdet.core import BitmapMasks, PolygonMasks
+from ..builder import PIPELINES
+
+
+@PIPELINES.register_module()
+class LoadImageFromFile(object):
+ """Load an image from file.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename"). Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is an uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='color',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ """Call functions to load image and get image meta information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = osp.join(results['img_prefix'],
+ results['img_info']['filename'])
+ else:
+ filename = results['img_info']['filename']
+
+ img_bytes = self.file_client.get(filename)
+ img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
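+
+
+# Editor's sketch (not part of the original file): the loader only needs
+# 'img_prefix' and 'img_info.filename'; 'demo/images/img_1.jpg' mirrors the
+# demo assets shipped with this repo and stands in for any readable image.
+def _example_load_image():
+    loader = LoadImageFromFile(to_float32=True)
+    results = dict(img_prefix=None,
+                   img_info=dict(filename='demo/images/img_1.jpg'))
+    results = loader(results)
+    return results['img'].dtype  # float32 because to_float32=True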
+
+
+@PIPELINES.register_module()
+class LoadImageFromWebcam(LoadImageFromFile):
+ """Load an image from webcam.
+
+    Similar to :obj:`LoadImageFromFile`, except that the image to use is
+    already present in ``results['img']``.
+ """
+
+ def __call__(self, results):
+ """Call functions to add image meta information.
+
+ Args:
+ results (dict): Result dict with Webcam read image in
+ ``results['img']``.
+
+ Returns:
+ dict: The dict contains loaded image and meta information.
+ """
+
+ img = results['img']
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = None
+ results['ori_filename'] = None
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ results['img_fields'] = ['img']
+ return results
+
+
+@PIPELINES.register_module()
+class LoadMultiChannelImageFromFiles(object):
+ """Load multi-channel images from a list of separate channel files.
+
+ Required keys are "img_prefix" and "img_info" (a dict that must contain the
+ key "filename", which is expected to be a list of filenames).
+ Added or updated keys are "filename", "img", "img_shape",
+ "ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
+ "scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
+
+ Args:
+ to_float32 (bool): Whether to convert the loaded image to a float32
+ numpy array. If set to False, the loaded image is an uint8 array.
+ Defaults to False.
+ color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
+ Defaults to 'color'.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ to_float32=False,
+ color_type='unchanged',
+ file_client_args=dict(backend='disk')):
+ self.to_float32 = to_float32
+ self.color_type = color_type
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def __call__(self, results):
+ """Call functions to load multiple images and get images meta
+ information.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded images and meta information.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ if results['img_prefix'] is not None:
+ filename = [
+ osp.join(results['img_prefix'], fname)
+ for fname in results['img_info']['filename']
+ ]
+ else:
+ filename = results['img_info']['filename']
+
+ img = []
+ for name in filename:
+ img_bytes = self.file_client.get(name)
+ img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
+ img = np.stack(img, axis=-1)
+ if self.to_float32:
+ img = img.astype(np.float32)
+
+ results['filename'] = filename
+ results['ori_filename'] = results['img_info']['filename']
+ results['img'] = img
+ results['img_shape'] = img.shape
+ results['ori_shape'] = img.shape
+ # Set initial values for default meta_keys
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = 1.0
+ num_channels = 1 if len(img.shape) < 3 else img.shape[2]
+ results['img_norm_cfg'] = dict(
+ mean=np.zeros(num_channels, dtype=np.float32),
+ std=np.ones(num_channels, dtype=np.float32),
+ to_rgb=False)
+ return results
+
+ def __repr__(self):
+ repr_str = (f'{self.__class__.__name__}('
+ f'to_float32={self.to_float32}, '
+ f"color_type='{self.color_type}', "
+ f'file_client_args={self.file_client_args})')
+ return repr_str
+
+
+@PIPELINES.register_module()
+class LoadAnnotations(object):
+ """Load mutiple types of annotations.
+
+ Args:
+ with_bbox (bool): Whether to parse and load the bbox annotation.
+ Default: True.
+ with_label (bool): Whether to parse and load the label annotation.
+ Default: True.
+ with_mask (bool): Whether to parse and load the mask annotation.
+ Default: False.
+ with_seg (bool): Whether to parse and load the semantic segmentation
+ annotation. Default: False.
+ poly2mask (bool): Whether to convert the instance masks from polygons
+ to bitmaps. Default: True.
+ file_client_args (dict): Arguments to instantiate a FileClient.
+ See :class:`mmcv.fileio.FileClient` for details.
+ Defaults to ``dict(backend='disk')``.
+ """
+
+ def __init__(self,
+ with_bbox=True,
+ with_label=True,
+ with_mask=False,
+ with_seg=False,
+ poly2mask=True,
+ file_client_args=dict(backend='disk')):
+ self.with_bbox = with_bbox
+ self.with_label = with_label
+ self.with_mask = with_mask
+ self.with_seg = with_seg
+ self.poly2mask = poly2mask
+ self.file_client_args = file_client_args.copy()
+ self.file_client = None
+
+ def _load_bboxes(self, results):
+ """Private function to load bounding box annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box annotations.
+ """
+
+ ann_info = results['ann_info']
+ results['gt_bboxes'] = ann_info['bboxes'].copy()
+ try:
+ results['gt_bboxes_3d'] = ann_info['bboxes_3d'].copy()
+ results['gt_bboxes_3d_proj'] = ann_info['bboxes_3d_proj'].copy()
+ results['bbox3d_fields'].append('gt_bboxes_3d')
+ results['bbox3d_fields'].append('gt_bboxes_3d_proj')
+        except KeyError:
+            # annotations without 3D boxes fall back to 2D-only loading
+            print('3d data not loaded')
+
+ gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
+ if gt_bboxes_ignore is not None:
+ results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
+ results['bbox_fields'].append('gt_bboxes_ignore')
+ results['bbox_fields'].append('gt_bboxes')
+ return results
+
+ def _load_labels(self, results):
+ """Private function to load label annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded label annotations.
+ """
+
+ results['gt_labels'] = results['ann_info']['labels'].copy()
+ return results
+
+ def _poly2mask(self, mask_ann, img_h, img_w):
+ """Private function to convert masks represented with polygon to
+ bitmaps.
+
+ Args:
+ mask_ann (list | dict): Polygon mask annotation input.
+ img_h (int): The height of output mask.
+ img_w (int): The width of output mask.
+
+ Returns:
+            numpy.ndarray: The decoded bitmap mask of shape (img_h, img_w).
+ """
+
+ if isinstance(mask_ann, list):
+ # polygon -- a single object might consist of multiple parts
+ # we merge all parts into one mask rle code
+ rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ rle = maskUtils.merge(rles)
+ elif isinstance(mask_ann['counts'], list):
+ # uncompressed RLE
+ rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
+ else:
+ # rle
+ rle = mask_ann
+ mask = maskUtils.decode(rle)
+ return mask
+
+ def process_polygons(self, polygons):
+ """Convert polygons to list of ndarray and filter invalid polygons.
+
+ Args:
+ polygons (list[list]): Polygons of one instance.
+
+ Returns:
+ list[numpy.ndarray]: Processed polygons.
+ """
+
+ polygons = [np.array(p) for p in polygons]
+ valid_polygons = []
+ for polygon in polygons:
+ if len(polygon) % 2 == 0 and len(polygon) >= 6:
+ valid_polygons.append(polygon)
+ return valid_polygons
+
+ def _load_masks(self, results):
+ """Private function to load mask annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded mask annotations.
+                If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
+                :obj:`BitmapMasks`. Otherwise, :obj:`PolygonMasks` is used.
+ """
+
+ h, w = results['img_info']['height'], results['img_info']['width']
+ gt_masks = results['ann_info']['masks']
+ if self.poly2mask:
+ gt_masks = BitmapMasks(
+ [self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
+ else:
+ gt_masks = PolygonMasks(
+ [self.process_polygons(polygons) for polygons in gt_masks], h,
+ w)
+ results['gt_masks'] = gt_masks
+ results['mask_fields'].append('gt_masks')
+ return results
+
+ def _load_semantic_seg(self, results):
+ """Private function to load semantic segmentation annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+ dict: The dict contains loaded semantic segmentation annotations.
+ """
+
+ if self.file_client is None:
+ self.file_client = mmcv.FileClient(**self.file_client_args)
+
+ filename = osp.join(results['seg_prefix'],
+ results['ann_info']['seg_map'])
+ img_bytes = self.file_client.get(filename)
+ results['gt_semantic_seg'] = mmcv.imfrombytes(
+ img_bytes, flag='unchanged').squeeze()
+ results['seg_fields'].append('gt_semantic_seg')
+ return results
+
+ def __call__(self, results):
+ """Call function to load multiple types annotations.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded bounding box, label, mask and
+ semantic segmentation annotations.
+ """
+
+ if self.with_bbox:
+ results = self._load_bboxes(results)
+ if results is None:
+ return None
+ if self.with_label:
+ results = self._load_labels(results)
+ if self.with_mask:
+ results = self._load_masks(results)
+ if self.with_seg:
+ results = self._load_semantic_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(with_bbox={self.with_bbox}, '
+ repr_str += f'with_label={self.with_label}, '
+ repr_str += f'with_mask={self.with_mask}, '
+ repr_str += f'with_seg={self.with_seg}, '
+ repr_str += f'poly2mask={self.poly2mask}, '
+        repr_str += f'file_client_args={self.file_client_args})'
+ return repr_str
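+
+
+# Editor's sketch (assumption, not part of the original file): the 3D boxes
+# above are optional -- when 'bboxes_3d' is missing from ann_info the loader
+# falls back to 2D-only annotations, so a minimal call looks like this:
+def _example_load_annotations():
+    loader = LoadAnnotations(with_bbox=True, with_label=True)
+    results = dict(
+        ann_info=dict(bboxes=np.zeros((1, 4), dtype=np.float32),
+                      labels=np.zeros(1, dtype=np.int64)),
+        bbox_fields=[])
+    results = loader(results)
+    return results['gt_bboxes'], results['gt_labels']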
+
+
+@PIPELINES.register_module()
+class LoadProposals(object):
+ """Load proposal pipeline.
+
+ Required key is "proposals". Updated keys are "proposals", "bbox_fields".
+
+ Args:
+ num_max_proposals (int, optional): Maximum number of proposals to load.
+ If not specified, all proposals will be loaded.
+ """
+
+ def __init__(self, num_max_proposals=None):
+ self.num_max_proposals = num_max_proposals
+
+ def __call__(self, results):
+ """Call function to load proposals from file.
+
+ Args:
+ results (dict): Result dict from :obj:`mmdet.CustomDataset`.
+
+ Returns:
+ dict: The dict contains loaded proposal annotations.
+ """
+
+ proposals = results['proposals']
+ if proposals.shape[1] not in (4, 5):
+ raise AssertionError(
+ 'proposals should have shapes (n, 4) or (n, 5), '
+ f'but found {proposals.shape}')
+ proposals = proposals[:, :4]
+
+ if self.num_max_proposals is not None:
+ proposals = proposals[:self.num_max_proposals]
+
+ if len(proposals) == 0:
+ proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
+ results['proposals'] = proposals
+ results['bbox_fields'].append('proposals')
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + \
+ f'(num_max_proposals={self.num_max_proposals})'
+
+
+@PIPELINES.register_module()
+class FilterAnnotations(object):
+ """Filter invalid annotations.
+
+ Args:
+ min_gt_bbox_wh (tuple[int]): Minimum width and height of ground truth
+ boxes.
+ """
+
+ def __init__(self, min_gt_bbox_wh):
+ # TODO: add more filter options
+ self.min_gt_bbox_wh = min_gt_bbox_wh
+
+ def __call__(self, results):
+ assert 'gt_bboxes' in results
+ gt_bboxes = results['gt_bboxes']
+ w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
+ h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
+ keep = (w > self.min_gt_bbox_wh[0]) & (h > self.min_gt_bbox_wh[1])
+ if not keep.any():
+ return None
+ else:
+ keys = ('gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg')
+ for key in keys:
+ if key in results:
+ results[key] = results[key][keep]
+ return results
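+
+
+# Editor's sketch (not part of the original file): boxes must be strictly
+# larger than min_gt_bbox_wh to survive; returning None drops the sample.
+def _example_filter_annotations():
+    f = FilterAnnotations(min_gt_bbox_wh=(2, 2))
+    results = dict(
+        gt_bboxes=np.array([[0, 0, 10, 10], [0, 0, 1, 1]], dtype=np.float32),
+        gt_labels=np.array([1, 2]))
+    results = f(results)
+    return results['gt_labels']  # array([1]): the 1x1 box was filtered out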
diff --git a/walt/datasets/pipelines/test_time_aug.py b/walt/datasets/pipelines/test_time_aug.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6226e040499882c99f15594c66ebf3d07829168
--- /dev/null
+++ b/walt/datasets/pipelines/test_time_aug.py
@@ -0,0 +1,119 @@
+import warnings
+
+import mmcv
+
+from ..builder import PIPELINES
+from .compose import Compose
+
+
+@PIPELINES.register_module()
+class MultiScaleFlipAug(object):
+ """Test-time augmentation with multiple scales and flipping.
+
+    An example configuration is as follows:
+
+ .. code-block::
+
+ img_scale=[(1333, 400), (1333, 800)],
+ flip=True,
+ transforms=[
+ dict(type='Resize', keep_ratio=True),
+ dict(type='RandomFlip'),
+ dict(type='Normalize', **img_norm_cfg),
+ dict(type='Pad', size_divisor=32),
+ dict(type='ImageToTensor', keys=['img']),
+ dict(type='Collect', keys=['img']),
+ ]
+
+    After MultiScaleFlipAug with the above configuration, the results are
+    wrapped into lists of the same length, as follows:
+
+ .. code-block::
+
+ dict(
+ img=[...],
+ img_shape=[...],
+ scale=[(1333, 400), (1333, 400), (1333, 800), (1333, 800)]
+ flip=[False, True, False, True]
+ ...
+ )
+
+ Args:
+ transforms (list[dict]): Transforms to apply in each augmentation.
+        img_scale (tuple | list[tuple] | None): Image scales for resizing.
+ scale_factor (float | list[float] | None): Scale factors for resizing.
+ flip (bool): Whether apply flip augmentation. Default: False.
+ flip_direction (str | list[str]): Flip augmentation directions,
+ options are "horizontal" and "vertical". If flip_direction is list,
+ multiple flip augmentations will be applied.
+ It has no effect when flip == False. Default: "horizontal".
+ """
+
+ def __init__(self,
+ transforms,
+ img_scale=None,
+ scale_factor=None,
+ flip=False,
+ flip_direction='horizontal'):
+ self.transforms = Compose(transforms)
+ assert (img_scale is None) ^ (scale_factor is None), (
+            'Exactly one of img_scale and scale_factor must be set')
+ if img_scale is not None:
+ self.img_scale = img_scale if isinstance(img_scale,
+ list) else [img_scale]
+ self.scale_key = 'scale'
+ assert mmcv.is_list_of(self.img_scale, tuple)
+ else:
+ self.img_scale = scale_factor if isinstance(
+ scale_factor, list) else [scale_factor]
+ self.scale_key = 'scale_factor'
+
+ self.flip = flip
+ self.flip_direction = flip_direction if isinstance(
+ flip_direction, list) else [flip_direction]
+ assert mmcv.is_list_of(self.flip_direction, str)
+ if not self.flip and self.flip_direction != ['horizontal']:
+ warnings.warn(
+ 'flip_direction has no effect when flip is set to False')
+ if (self.flip
+ and not any([t['type'] == 'RandomFlip' for t in transforms])):
+ warnings.warn(
+ 'flip has no effect when RandomFlip is not in transforms')
+
+ def __call__(self, results):
+ """Call function to apply test time augment transforms on results.
+
+ Args:
+ results (dict): Result dict contains the data to transform.
+
+ Returns:
+ dict[str: list]: The augmented data, where each value is wrapped
+ into a list.
+ """
+
+ aug_data = []
+ flip_args = [(False, None)]
+ if self.flip:
+ flip_args += [(True, direction)
+ for direction in self.flip_direction]
+ for scale in self.img_scale:
+ for flip, direction in flip_args:
+ _results = results.copy()
+ _results[self.scale_key] = scale
+ _results['flip'] = flip
+ _results['flip_direction'] = direction
+ data = self.transforms(_results)
+ aug_data.append(data)
+ # list of dict to dict of list
+ aug_data_dict = {key: [] for key in aug_data[0]}
+ for data in aug_data:
+ for key, val in data.items():
+ aug_data_dict[key].append(val)
+ return aug_data_dict
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(transforms={self.transforms}, '
+ repr_str += f'img_scale={self.img_scale}, flip={self.flip}, '
+ repr_str += f'flip_direction={self.flip_direction})'
+ return repr_str
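+
+
+# Editor's note (not part of the original file): with N scales and flipping
+# enabled for D directions, each sample expands into N * (1 + D) augmented
+# copies -- e.g. 2 scales with a horizontal flip yield the 4 entries shown in
+# the class docstring above.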
diff --git a/walt/datasets/pipelines/transforms.py b/walt/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..02fd63f2bfaac64fbf9495f2fe6ffe83dc9371e1
--- /dev/null
+++ b/walt/datasets/pipelines/transforms.py
@@ -0,0 +1,1861 @@
+import copy
+import inspect
+
+import mmcv
+import numpy as np
+from numpy import random
+
+from mmdet.core import PolygonMasks
+from mmdet.core.evaluation.bbox_overlaps import bbox_overlaps
+from ..builder import PIPELINES
+
+try:
+ from imagecorruptions import corrupt
+except ImportError:
+ corrupt = None
+
+try:
+ import albumentations
+ from albumentations import Compose
+except ImportError:
+ albumentations = None
+ Compose = None
+
+
+@PIPELINES.register_module()
+class Resize(object):
+ """Resize images & bbox & mask.
+
+ This transform resizes the input image to some scale. Bboxes and masks are
+ then resized with the same scale factor. If the input dict contains the key
+ "scale", then the scale in the input dict is used, otherwise the specified
+ scale in the init method is used. If the input dict contains the key
+ "scale_factor" (if MultiScaleFlipAug does not give img_scale but
+ scale_factor), the actual scale will be computed by image shape and
+ scale_factor.
+
+ `img_scale` can either be a tuple (single-scale) or a list of tuple
+ (multi-scale). There are 3 multiscale modes:
+
+ - ``ratio_range is not None``: randomly sample a ratio from the ratio \
+ range and multiply it with the image scale.
+ - ``ratio_range is None`` and ``multiscale_mode == "range"``: randomly \
+ sample a scale from the multiscale range.
+ - ``ratio_range is None`` and ``multiscale_mode == "value"``: randomly \
+ sample a scale from multiple scales.
+
+ Args:
+        img_scale (tuple or list[tuple]): Image scales for resizing.
+ multiscale_mode (str): Either "range" or "value".
+ ratio_range (tuple[float]): (min_ratio, max_ratio)
+ keep_ratio (bool): Whether to keep the aspect ratio when resizing the
+ image.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+ backend (str): Image resize backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results. Defaults
+ to 'cv2'.
+ override (bool, optional): Whether to override `scale` and
+ `scale_factor` so as to call resize twice. Default False. If True,
+            after the first resizing, the existing `scale` and `scale_factor`
+ will be ignored so the second resizing can be allowed.
+ This option is a work-around for multiple times of resize in DETR.
+ Defaults to False.
+ """
+
+ def __init__(self,
+ img_scale=None,
+ multiscale_mode='range',
+ ratio_range=None,
+ keep_ratio=True,
+ bbox_clip_border=True,
+ backend='cv2',
+ override=False):
+ if img_scale is None:
+ self.img_scale = None
+ else:
+ if isinstance(img_scale, list):
+ self.img_scale = img_scale
+ else:
+ self.img_scale = [img_scale]
+ assert mmcv.is_list_of(self.img_scale, tuple)
+
+ if ratio_range is not None:
+ # mode 1: given a scale and a range of image ratio
+ assert len(self.img_scale) == 1
+ else:
+ # mode 2: given multiple scales or a range of scales
+ assert multiscale_mode in ['value', 'range']
+
+ self.backend = backend
+ self.multiscale_mode = multiscale_mode
+ self.ratio_range = ratio_range
+ self.keep_ratio = keep_ratio
+ # TODO: refactor the override option in Resize
+ self.override = override
+ self.bbox_clip_border = bbox_clip_border
+
+ @staticmethod
+ def random_select(img_scales):
+ """Randomly select an img_scale from given candidates.
+
+ Args:
+ img_scales (list[tuple]): Images scales for selection.
+
+ Returns:
+            (tuple, int): Returns a tuple ``(img_scale, scale_idx)``, \
+ where ``img_scale`` is the selected image scale and \
+ ``scale_idx`` is the selected index in the given candidates.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple)
+ scale_idx = np.random.randint(len(img_scales))
+ img_scale = img_scales[scale_idx]
+ return img_scale, scale_idx
+
+ @staticmethod
+ def random_sample(img_scales):
+ """Randomly sample an img_scale when ``multiscale_mode=='range'``.
+
+ Args:
+ img_scales (list[tuple]): Images scale range for sampling.
+ There must be two tuples in img_scales, which specify the lower
+ and upper bound of image scales.
+
+ Returns:
+            (tuple, None): Returns a tuple ``(img_scale, None)``, where \
+                ``img_scale`` is the sampled scale and None is just a \
+                placeholder to be consistent with :func:`random_select`.
+ """
+
+ assert mmcv.is_list_of(img_scales, tuple) and len(img_scales) == 2
+ img_scale_long = [max(s) for s in img_scales]
+ img_scale_short = [min(s) for s in img_scales]
+ long_edge = np.random.randint(
+ min(img_scale_long),
+ max(img_scale_long) + 1)
+ short_edge = np.random.randint(
+ min(img_scale_short),
+ max(img_scale_short) + 1)
+ img_scale = (long_edge, short_edge)
+ return img_scale, None
+
+ @staticmethod
+ def random_sample_ratio(img_scale, ratio_range):
+ """Randomly sample an img_scale when ``ratio_range`` is specified.
+
+ A ratio will be randomly sampled from the range specified by
+ ``ratio_range``. Then it would be multiplied with ``img_scale`` to
+ generate sampled scale.
+
+ Args:
+ img_scale (tuple): Images scale base to multiply with ratio.
+ ratio_range (tuple[float]): The minimum and maximum ratio to scale
+ the ``img_scale``.
+
+ Returns:
+            (tuple, None): Returns a tuple ``(scale, None)``, where \
+                ``scale`` is the sampled ratio multiplied with ``img_scale`` \
+                and None is just a placeholder to be consistent with \
+                :func:`random_select`.
+ """
+
+ assert isinstance(img_scale, tuple) and len(img_scale) == 2
+ min_ratio, max_ratio = ratio_range
+ assert min_ratio <= max_ratio
+ ratio = np.random.random_sample() * (max_ratio - min_ratio) + min_ratio
+ scale = int(img_scale[0] * ratio), int(img_scale[1] * ratio)
+ return scale, None
+
+ def _random_scale(self, results):
+ """Randomly sample an img_scale according to ``ratio_range`` and
+ ``multiscale_mode``.
+
+ If ``ratio_range`` is specified, a ratio will be sampled and be
+ multiplied with ``img_scale``.
+ If multiple scales are specified by ``img_scale``, a scale will be
+ sampled according to ``multiscale_mode``.
+ Otherwise, single scale will be used.
+
+ Args:
+ results (dict): Result dict from :obj:`dataset`.
+
+ Returns:
+            dict: Two new keys ``scale`` and ``scale_idx`` are added into \
+ ``results``, which would be used by subsequent pipelines.
+ """
+
+ if self.ratio_range is not None:
+ scale, scale_idx = self.random_sample_ratio(
+ self.img_scale[0], self.ratio_range)
+ elif len(self.img_scale) == 1:
+ scale, scale_idx = self.img_scale[0], 0
+ elif self.multiscale_mode == 'range':
+ scale, scale_idx = self.random_sample(self.img_scale)
+ elif self.multiscale_mode == 'value':
+ scale, scale_idx = self.random_select(self.img_scale)
+ else:
+ raise NotImplementedError
+
+ results['scale'] = scale
+ results['scale_idx'] = scale_idx
+
+ def _resize_img(self, results):
+ """Resize images with ``results['scale']``."""
+ for key in results.get('img_fields', ['img']):
+ if self.keep_ratio:
+ img, scale_factor = mmcv.imrescale(
+ results[key],
+ results['scale'],
+ return_scale=True,
+ backend=self.backend)
+                # w_scale and h_scale may differ slightly due to rounding;
+                # a real fix should be done in mmcv.imrescale in the future
+ new_h, new_w = img.shape[:2]
+ h, w = results[key].shape[:2]
+ w_scale = new_w / w
+ h_scale = new_h / h
+ else:
+ img, w_scale, h_scale = mmcv.imresize(
+ results[key],
+ results['scale'],
+ return_scale=True,
+ backend=self.backend)
+ results[key] = img
+
+ scale_factor = np.array([w_scale, h_scale, w_scale, h_scale],
+ dtype=np.float32)
+ results['img_shape'] = img.shape
+ # in case that there is no padding
+ results['pad_shape'] = img.shape
+ results['scale_factor'] = scale_factor
+ results['keep_ratio'] = self.keep_ratio
+
+ def _resize_bboxes(self, results):
+ """Resize bounding boxes with ``results['scale_factor']``."""
+ for key in results.get('bbox_fields', []):
+ bboxes = results[key] * results['scale_factor']
+ if self.bbox_clip_border:
+ img_shape = results['img_shape']
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+ results[key] = bboxes
+
+ def _resize_bboxes3d(self, results):
+ """Resize bounding boxes with ``results['scale_factor']``."""
+ key = 'gt_bboxes_3d_proj'
+ bboxes3d_proj = results[key][:,:,:2]
+ img_shape = results['img_shape']
+ for i in range(results[key].shape[1]):
+ bboxes3d_proj[:,i,:] = bboxes3d_proj[:,i,:] * results['scale_factor'][:2]
+ if self.bbox_clip_border:
+ bboxes3d_proj[:, i, 0] = np.clip(bboxes3d_proj[:, i, 0], 0, img_shape[1])
+                # y coordinates must be clipped by the image height, not width
+                bboxes3d_proj[:, i, 1] = np.clip(bboxes3d_proj[:, i, 1], 0, img_shape[0])
+ results[key] = bboxes3d_proj
+
+ def _resize_masks(self, results):
+ """Resize masks with ``results['scale']``"""
+ for key in results.get('mask_fields', []):
+ if results[key] is None:
+ continue
+ if self.keep_ratio:
+ results[key] = results[key].rescale(results['scale'])
+ else:
+ results[key] = results[key].resize(results['img_shape'][:2])
+
+ def _resize_seg(self, results):
+ """Resize semantic segmentation map with ``results['scale']``."""
+ for key in results.get('seg_fields', []):
+ if self.keep_ratio:
+ gt_seg = mmcv.imrescale(
+ results[key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ else:
+ gt_seg = mmcv.imresize(
+ results[key],
+ results['scale'],
+ interpolation='nearest',
+ backend=self.backend)
+ results['gt_semantic_seg'] = gt_seg
+
+ def __call__(self, results):
+ """Call function to resize images, bounding boxes, masks, semantic
+ segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Resized results, 'img_shape', 'pad_shape', 'scale_factor', \
+ 'keep_ratio' keys are added into result dict.
+ """
+
+ if 'scale' not in results:
+ if 'scale_factor' in results:
+ img_shape = results['img'].shape[:2]
+ scale_factor = results['scale_factor']
+ assert isinstance(scale_factor, float)
+ results['scale'] = tuple(
+ [int(x * scale_factor) for x in img_shape][::-1])
+ else:
+ self._random_scale(results)
+ else:
+ if not self.override:
+ assert 'scale_factor' not in results, (
+ 'scale and scale_factor cannot be both set.')
+ else:
+ results.pop('scale')
+ if 'scale_factor' in results:
+ results.pop('scale_factor')
+ self._random_scale(results)
+
+ self._resize_img(results)
+ self._resize_bboxes(results)
+ self._resize_bboxes3d(results)
+ self._resize_masks(results)
+ self._resize_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(img_scale={self.img_scale}, '
+ repr_str += f'multiscale_mode={self.multiscale_mode}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'keep_ratio={self.keep_ratio}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
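+
+
+# Editor's sketch (not part of the original file): a typical multi-scale
+# training setup; 'range' mode samples a scale between the two tuples on
+# every call, and keep_ratio=True rescales while preserving aspect ratio.
+def _example_resize():
+    return Resize(img_scale=[(1333, 640), (1333, 800)],
+                  multiscale_mode='range', keep_ratio=True)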
+
+
+@PIPELINES.register_module()
+class RandomFlip(object):
+ """Flip the image & bbox & mask.
+
+ If the input dict contains the key "flip", then the flag will be used,
+ otherwise it will be randomly decided by a ratio specified in the init
+ method.
+
+ When random flip is enabled, ``flip_ratio``/``direction`` can either be a
+ float/string or tuple of float/string. There are 3 flip modes:
+
+    - ``flip_ratio`` is float, ``direction`` is string: the image will be
+      flipped in ``direction`` with probability ``flip_ratio``.
+      E.g., ``flip_ratio=0.5``, ``direction='horizontal'``,
+      then the image will be horizontally flipped with probability of 0.5.
+    - ``flip_ratio`` is float, ``direction`` is list of string: the image will
+      be flipped in ``direction[i]`` with probability of
+      ``flip_ratio/len(direction)``.
+      E.g., ``flip_ratio=0.5``, ``direction=['horizontal', 'vertical']``,
+      then the image will be horizontally flipped with probability of 0.25,
+      vertically with probability of 0.25.
+    - ``flip_ratio`` is list of float, ``direction`` is list of string:
+      given ``len(flip_ratio) == len(direction)``, the image will
+      be flipped in ``direction[i]`` with probability ``flip_ratio[i]``.
+      E.g., ``flip_ratio=[0.3, 0.5]``, ``direction=['horizontal',
+      'vertical']``, then the image will be horizontally flipped with
+      probability of 0.3, vertically with probability of 0.5.
+
+ Args:
+ flip_ratio (float | list[float], optional): The flipping probability.
+ Default: None.
+ direction(str | list[str], optional): The flipping direction. Options
+ are 'horizontal', 'vertical', 'diagonal'. Default: 'horizontal'.
+ If input is a list, the length must equal ``flip_ratio``. Each
+ element in ``flip_ratio`` indicates the flip probability of
+ corresponding direction.
+ """
+
+ def __init__(self, flip_ratio=None, direction='horizontal'):
+ if isinstance(flip_ratio, list):
+ assert mmcv.is_list_of(flip_ratio, float)
+ assert 0 <= sum(flip_ratio) <= 1
+ elif isinstance(flip_ratio, float):
+ assert 0 <= flip_ratio <= 1
+ elif flip_ratio is None:
+ pass
+ else:
+ raise ValueError('flip_ratios must be None, float, '
+ 'or list of float')
+ self.flip_ratio = flip_ratio
+
+ valid_directions = ['horizontal', 'vertical', 'diagonal']
+ if isinstance(direction, str):
+ assert direction in valid_directions
+ elif isinstance(direction, list):
+ assert mmcv.is_list_of(direction, str)
+ assert set(direction).issubset(set(valid_directions))
+ else:
+ raise ValueError('direction must be either str or list of str')
+ self.direction = direction
+
+ if isinstance(flip_ratio, list):
+ assert len(self.flip_ratio) == len(self.direction)
+
+ def bbox_flip(self, bboxes, img_shape, direction):
+ """Flip bboxes horizontally.
+
+ Args:
+ bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
+ img_shape (tuple[int]): Image shape (height, width)
+ direction (str): Flip direction. Options are 'horizontal',
+ 'vertical'.
+
+ Returns:
+ numpy.ndarray: Flipped bounding boxes.
+ """
+
+ assert bboxes.shape[-1] % 4 == 0
+ flipped = bboxes.copy()
+ if direction == 'horizontal':
+ w = img_shape[1]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ elif direction == 'vertical':
+ h = img_shape[0]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ elif direction == 'diagonal':
+ w = img_shape[1]
+ h = img_shape[0]
+ flipped[..., 0::4] = w - bboxes[..., 2::4]
+ flipped[..., 1::4] = h - bboxes[..., 3::4]
+ flipped[..., 2::4] = w - bboxes[..., 0::4]
+ flipped[..., 3::4] = h - bboxes[..., 1::4]
+ else:
+ raise ValueError(f"Invalid flipping direction '{direction}'")
+ return flipped
+
+ def bbox3d_proj_flip(self, bboxes, img_shape, direction):
+ """Flip bboxes horizontally.
+
+ Args:
+ bboxes (numpy.ndarray): Bounding boxes, shape (..., 4*k)
+ img_shape (tuple[int]): Image shape (height, width)
+ direction (str): Flip direction. Options are 'horizontal',
+ 'vertical'.
+
+ Returns:
+ numpy.ndarray: Flipped bounding boxes.
+ """
+
+ flipped = bboxes.copy()
+ if direction == 'horizontal':
+ w = img_shape[1]
+
+ flipped[:,:,0] = w - bboxes[:,:, 0]
+ elif direction == 'vertical':
+ h = img_shape[0]
+ flipped[:,:,1] = h - bboxes[:,:, 1]
+ elif direction == 'diagonal':
+ w = img_shape[1]
+ h = img_shape[0]
+ flipped[:,:,0] = w - bboxes[:,:, 0]
+ flipped[:,:,1] = h - bboxes[:,:, 1]
+ else:
+ raise ValueError(f"Invalid flipping direction '{direction}'")
+ flipped[bboxes == -100] = -100
+ return flipped
+
+ def __call__(self, results):
+ """Call function to flip bounding boxes, masks, semantic segmentation
+ maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Flipped results, 'flip', 'flip_direction' keys are added \
+ into result dict.
+ """
+
+ if 'flip' not in results:
+ if isinstance(self.direction, list):
+ # None means non-flip
+ direction_list = self.direction + [None]
+ else:
+ # None means non-flip
+ direction_list = [self.direction, None]
+
+ if isinstance(self.flip_ratio, list):
+ non_flip_ratio = 1 - sum(self.flip_ratio)
+ flip_ratio_list = self.flip_ratio + [non_flip_ratio]
+ else:
+ non_flip_ratio = 1 - self.flip_ratio
+ # exclude non-flip
+ single_ratio = self.flip_ratio / (len(direction_list) - 1)
+ flip_ratio_list = [single_ratio] * (len(direction_list) -
+ 1) + [non_flip_ratio]
+
+ cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
+
+ results['flip'] = cur_dir is not None
+ if 'flip_direction' not in results:
+ results['flip_direction'] = cur_dir
+ if results['flip']:
+ # flip image
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction'])
+ # flip bboxes
+ for key in results.get('bbox_fields', []):
+ results[key] = self.bbox_flip(results[key],
+ results['img_shape'],
+ results['flip_direction'])
+ for key in results.get('bbox3d_fields', []):
+ if '_proj' in key:
+ results[key] = self.bbox3d_proj_flip(results[key],
+ results['img_shape'],
+ results['flip_direction'])
+ # flip masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].flip(results['flip_direction'])
+
+ # flip segs
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.imflip(
+ results[key], direction=results['flip_direction'])
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(flip_ratio={self.flip_ratio})'
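+
+
+# Editor's sketch (not part of the original file): a minimal standalone call;
+# the empty *_fields lists are placeholders that a real loading pipeline
+# would have populated.
+def _example_random_flip():
+    flip = RandomFlip(flip_ratio=0.5, direction='horizontal')
+    results = dict(img=np.zeros((4, 8, 3), dtype=np.uint8),
+                   img_shape=(4, 8, 3), bbox_fields=[], bbox3d_fields=[],
+                   mask_fields=[], seg_fields=[])
+    results = flip(results)
+    return results['flip'], results['flip_direction']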
+
+
+@PIPELINES.register_module()
+class Pad(object):
+ """Pad the image & mask.
+
+ There are two padding modes: (1) pad to a fixed size and (2) pad to the
+ minimum size that is divisible by some number.
+ Added keys are "pad_shape", "pad_fixed_size", "pad_size_divisor",
+
+ Args:
+ size (tuple, optional): Fixed padding size.
+ size_divisor (int, optional): The divisor of padded size.
+ pad_val (float, optional): Padding value, 0 by default.
+ """
+
+ def __init__(self, size=None, size_divisor=None, pad_val=0):
+ self.size = size
+ self.size_divisor = size_divisor
+ self.pad_val = pad_val
+ # only one of size and size_divisor should be valid
+ assert size is not None or size_divisor is not None
+ assert size is None or size_divisor is None
+
+ def _pad_img(self, results):
+ """Pad images according to ``self.size``."""
+ for key in results.get('img_fields', ['img']):
+ if self.size is not None:
+ padded_img = mmcv.impad(
+ results[key], shape=self.size, pad_val=self.pad_val)
+ elif self.size_divisor is not None:
+ padded_img = mmcv.impad_to_multiple(
+ results[key], self.size_divisor, pad_val=self.pad_val)
+ results[key] = padded_img
+ results['pad_shape'] = padded_img.shape
+ results['pad_fixed_size'] = self.size
+ results['pad_size_divisor'] = self.size_divisor
+
+ def _pad_masks(self, results):
+ """Pad masks according to ``results['pad_shape']``."""
+ pad_shape = results['pad_shape'][:2]
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].pad(pad_shape, pad_val=self.pad_val)
+
+ def _pad_seg(self, results):
+ """Pad semantic segmentation map according to
+ ``results['pad_shape']``."""
+ for key in results.get('seg_fields', []):
+ results[key] = mmcv.impad(
+ results[key], shape=results['pad_shape'][:2])
+
+ def __call__(self, results):
+ """Call function to pad images, masks, semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Updated result dict.
+ """
+ self._pad_img(results)
+ self._pad_masks(results)
+ self._pad_seg(results)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(size={self.size}, '
+ repr_str += f'size_divisor={self.size_divisor}, '
+ repr_str += f'pad_val={self.pad_val})'
+ return repr_str
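+
+
+# Editor's sketch (not part of the original file): size_divisor=32 pads each
+# side up to the next multiple of 32, the usual requirement for FPN strides.
+def _example_pad():
+    pad = Pad(size_divisor=32)
+    results = dict(img=np.zeros((100, 50, 3), dtype=np.uint8),
+                   mask_fields=[], seg_fields=[])
+    results = pad(results)
+    return results['pad_shape']  # (128, 64, 3)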
+
+
+@PIPELINES.register_module()
+class Normalize(object):
+ """Normalize the image.
+
+ Added key is "img_norm_cfg".
+
+ Args:
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB,
+ default is true.
+ """
+
+ def __init__(self, mean, std, to_rgb=True):
+ self.mean = np.array(mean, dtype=np.float32)
+ self.std = np.array(std, dtype=np.float32)
+ self.to_rgb = to_rgb
+
+ def __call__(self, results):
+ """Call function to normalize images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Normalized results, 'img_norm_cfg' key is added into
+ result dict.
+ """
+ for key in results.get('img_fields', ['img']):
+ results[key] = mmcv.imnormalize(results[key], self.mean, self.std,
+ self.to_rgb)
+ results['img_norm_cfg'] = dict(
+ mean=self.mean, std=self.std, to_rgb=self.to_rgb)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, std={self.std}, to_rgb={self.to_rgb})'
+ return repr_str
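+
+
+# Editor's sketch (not part of the original file): the ImageNet statistics
+# that also appear in the WrapFieldsToLists docstring; to_rgb=True converts
+# the BGR input that mmcv loads by default.
+def _example_normalize():
+    norm = Normalize(mean=[123.675, 116.28, 103.53],
+                     std=[58.395, 57.12, 57.375], to_rgb=True)
+    results = dict(img=np.zeros((4, 4, 3), dtype=np.float32))
+    return norm(results)['img_norm_cfg']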
+
+
+@PIPELINES.register_module()
+class RandomCrop(object):
+ """Random crop the image & bboxes & masks.
+
+ The absolute `crop_size` is sampled based on `crop_type` and `image_size`,
+ then the cropped results are generated.
+
+ Args:
+ crop_size (tuple): The relative ratio or absolute pixels of
+ height and width.
+ crop_type (str, optional): one of "relative_range", "relative",
+ "absolute", "absolute_range". "relative" randomly crops
+ (h * crop_size[0], w * crop_size[1]) part from an input of size
+ (h, w). "relative_range" uniformly samples relative crop size from
+ range [crop_size[0], 1] and [crop_size[1], 1] for height and width
+ respectively. "absolute" crops from an input with absolute size
+ (crop_size[0], crop_size[1]). "absolute_range" uniformly samples
+ crop_h in range [crop_size[0], min(h, crop_size[1])] and crop_w
+ in range [crop_size[0], min(w, crop_size[1])]. Default "absolute".
+ allow_negative_crop (bool, optional): Whether to allow a crop that does
+ not contain any bbox area. Default False.
+ bbox_clip_border (bool, optional): Whether clip the objects outside
+ the border of the image. Defaults to True.
+
+ Note:
+ - If the image is smaller than the absolute crop size, return the
+ original image.
+ - The keys for bboxes, labels and masks must be aligned. That is,
+ `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and
+ `gt_bboxes_ignore` corresponds to `gt_labels_ignore` and
+ `gt_masks_ignore`.
+ - If the crop does not contain any gt-bbox region and
+ `allow_negative_crop` is set to False, skip this image.
+ """
+
+ def __init__(self,
+ crop_size,
+ crop_type='absolute',
+ allow_negative_crop=False,
+ bbox_clip_border=True):
+ if crop_type not in [
+ 'relative_range', 'relative', 'absolute', 'absolute_range'
+ ]:
+ raise ValueError(f'Invalid crop_type {crop_type}.')
+ if crop_type in ['absolute', 'absolute_range']:
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ assert isinstance(crop_size[0], int) and isinstance(
+ crop_size[1], int)
+ else:
+ assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
+ self.crop_size = crop_size
+ self.crop_type = crop_type
+ self.allow_negative_crop = allow_negative_crop
+ self.bbox_clip_border = bbox_clip_border
+ # The key correspondence from bboxes to labels and masks.
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ self.bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+
+ def _crop_data(self, results, crop_size, allow_negative_crop):
+ """Function to randomly crop images, bounding boxes, masks, semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+ crop_size (tuple): Expected absolute size after cropping, (h, w).
+ allow_negative_crop (bool): Whether to allow a crop that does not
+ contain any bbox area. Default to False.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ for key in results.get('img_fields', ['img']):
+ img = results[key]
+ margin_h = max(img.shape[0] - crop_size[0], 0)
+ margin_w = max(img.shape[1] - crop_size[1], 0)
+ offset_h = np.random.randint(0, margin_h + 1)
+ offset_w = np.random.randint(0, margin_w + 1)
+ crop_y1, crop_y2 = offset_h, offset_h + crop_size[0]
+ crop_x1, crop_x2 = offset_w, offset_w + crop_size[1]
+
+ # crop the image
+ img = img[crop_y1:crop_y2, crop_x1:crop_x2, ...]
+ img_shape = img.shape
+ results[key] = img
+ results['img_shape'] = img_shape
+
+ # crop bboxes accordingly and clip to the image boundary
+ for key in results.get('bbox_fields', []):
+ # e.g. gt_bboxes and gt_bboxes_ignore
+ bbox_offset = np.array([offset_w, offset_h, offset_w, offset_h],
+ dtype=np.float32)
+ bboxes = results[key] - bbox_offset
+ if self.bbox_clip_border:
+ bboxes[:, 0::2] = np.clip(bboxes[:, 0::2], 0, img_shape[1])
+ bboxes[:, 1::2] = np.clip(bboxes[:, 1::2], 0, img_shape[0])
+ valid_inds = (bboxes[:, 2] > bboxes[:, 0]) & (
+ bboxes[:, 3] > bboxes[:, 1])
+ # If the crop does not contain any gt-bbox area and
+ # allow_negative_crop is False, skip this image.
+ if (key == 'gt_bboxes' and not valid_inds.any()
+ and not allow_negative_crop):
+ return None
+ results[key] = bboxes[valid_inds, :]
+ # label fields. e.g. gt_labels and gt_labels_ignore
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][valid_inds]
+
+ # mask fields, e.g. gt_masks and gt_masks_ignore
+ mask_key = self.bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][
+ valid_inds.nonzero()[0]].crop(
+ np.asarray([crop_x1, crop_y1, crop_x2, crop_y2]))
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ results[key] = results[key][crop_y1:crop_y2, crop_x1:crop_x2]
+
+ return results
+
+ def _get_crop_size(self, image_size):
+ """Randomly generates the absolute crop size based on `crop_type` and
+ `image_size`.
+
+ Args:
+ image_size (tuple): (h, w).
+
+ Returns:
+ crop_size (tuple): (crop_h, crop_w) in absolute pixels.
+ """
+ h, w = image_size
+ if self.crop_type == 'absolute':
+ return (min(self.crop_size[0], h), min(self.crop_size[1], w))
+ elif self.crop_type == 'absolute_range':
+ assert self.crop_size[0] <= self.crop_size[1]
+ crop_h = np.random.randint(
+ min(h, self.crop_size[0]),
+ min(h, self.crop_size[1]) + 1)
+ crop_w = np.random.randint(
+ min(w, self.crop_size[0]),
+ min(w, self.crop_size[1]) + 1)
+ return crop_h, crop_w
+ elif self.crop_type == 'relative':
+ crop_h, crop_w = self.crop_size
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+ elif self.crop_type == 'relative_range':
+ crop_size = np.asarray(self.crop_size, dtype=np.float32)
+ crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+
+ def __call__(self, results):
+ """Call function to randomly crop images, bounding boxes, masks,
+ semantic segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Randomly cropped results, 'img_shape' key in result dict is
+ updated according to crop size.
+ """
+ image_size = results['img'].shape[:2]
+ crop_size = self._get_crop_size(image_size)
+ results = self._crop_data(results, crop_size, self.allow_negative_crop)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(crop_size={self.crop_size}, '
+ repr_str += f'crop_type={self.crop_type}, '
+ repr_str += f'allow_negative_crop={self.allow_negative_crop}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class SegRescale(object):
+ """Rescale semantic segmentation maps.
+
+ Args:
+ scale_factor (float): The scale factor of the final output.
+        backend (str): Image rescale backend, choices are 'cv2' and 'pillow'.
+            These two backends generate slightly different results. Defaults
+            to 'cv2'.
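+
+    Example (a minimal sketch; the 1/8 factor is only an assumption matching
+    a typical semantic-head stride, not a required value):
+
+    .. code-block::
+
+        dict(type='SegRescale', scale_factor=1 / 8)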
+ """
+
+ def __init__(self, scale_factor=1, backend='cv2'):
+ self.scale_factor = scale_factor
+ self.backend = backend
+
+ def __call__(self, results):
+ """Call function to scale the semantic segmentation map.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with semantic segmentation map scaled.
+ """
+
+ for key in results.get('seg_fields', []):
+ if self.scale_factor != 1:
+ results[key] = mmcv.imrescale(
+ results[key],
+ self.scale_factor,
+ interpolation='nearest',
+ backend=self.backend)
+ return results
+
+ def __repr__(self):
+ return self.__class__.__name__ + f'(scale_factor={self.scale_factor})'
+
+
+@PIPELINES.register_module()
+class PhotoMetricDistortion(object):
+ """Apply photometric distortion to image sequentially, every transformation
+ is applied with a probability of 0.5. The position of random contrast is in
+ second or second to last.
+
+ 1. random brightness
+ 2. random contrast (mode 0)
+ 3. convert color from BGR to HSV
+ 4. random saturation
+ 5. random hue
+ 6. convert color from HSV to BGR
+ 7. random contrast (mode 1)
+ 8. randomly swap channels
+
+ Args:
+ brightness_delta (int): delta of brightness.
+ contrast_range (tuple): range of contrast.
+ saturation_range (tuple): range of saturation.
+ hue_delta (int): delta of hue.
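+
+    Example (an illustrative pipeline entry using the defaults above; note
+    that the input image must be float32, e.g. via ``to_float32=True`` in
+    ``LoadImageFromFile``):
+
+    .. code-block::
+
+        dict(type='PhotoMetricDistortion',
+             brightness_delta=32,
+             contrast_range=(0.5, 1.5),
+             saturation_range=(0.5, 1.5),
+             hue_delta=18)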
+ """
+
+ def __init__(self,
+ brightness_delta=32,
+ contrast_range=(0.5, 1.5),
+ saturation_range=(0.5, 1.5),
+ hue_delta=18):
+ self.brightness_delta = brightness_delta
+ self.contrast_lower, self.contrast_upper = contrast_range
+ self.saturation_lower, self.saturation_upper = saturation_range
+ self.hue_delta = hue_delta
+
+ def __call__(self, results):
+ """Call function to perform photometric distortion on images.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images distorted.
+ """
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+ assert img.dtype == np.float32, \
+ 'PhotoMetricDistortion needs the input image of dtype np.float32,'\
+ ' please set "to_float32=True" in "LoadImageFromFile" pipeline'
+ # random brightness
+ if random.randint(2):
+ delta = random.uniform(-self.brightness_delta,
+ self.brightness_delta)
+ img += delta
+
+ # mode == 0 --> do random contrast first
+ # mode == 1 --> do random contrast last
+ mode = random.randint(2)
+ if mode == 1:
+ if random.randint(2):
+ alpha = random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ img *= alpha
+
+ # convert color from BGR to HSV
+ img = mmcv.bgr2hsv(img)
+
+ # random saturation
+ if random.randint(2):
+ img[..., 1] *= random.uniform(self.saturation_lower,
+ self.saturation_upper)
+
+ # random hue
+ if random.randint(2):
+ img[..., 0] += random.uniform(-self.hue_delta, self.hue_delta)
+ img[..., 0][img[..., 0] > 360] -= 360
+ img[..., 0][img[..., 0] < 0] += 360
+
+ # convert color from HSV to BGR
+ img = mmcv.hsv2bgr(img)
+
+ # random contrast
+ if mode == 0:
+ if random.randint(2):
+ alpha = random.uniform(self.contrast_lower,
+ self.contrast_upper)
+ img *= alpha
+
+ # randomly swap channels
+ if random.randint(2):
+ img = img[..., random.permutation(3)]
+
+ results['img'] = img
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(\nbrightness_delta={self.brightness_delta},\n'
+ repr_str += 'contrast_range='
+ repr_str += f'{(self.contrast_lower, self.contrast_upper)},\n'
+ repr_str += 'saturation_range='
+ repr_str += f'{(self.saturation_lower, self.saturation_upper)},\n'
+ repr_str += f'hue_delta={self.hue_delta})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Expand(object):
+ """Random expand the image & bboxes.
+
+    Randomly place the original image on a canvas of ``ratio`` x original
+    image size, filled with mean values. The ratio is sampled from
+    ``ratio_range``.
+
+ Args:
+ mean (tuple): mean value of dataset.
+        to_rgb (bool): whether to convert the order of mean to align with RGB.
+        ratio_range (tuple): range of expand ratio.
+        seg_ignore_label (int, optional): label of the ignore region in
+            semantic segmentation maps, used to fill the expanded area.
+            Defaults to None.
+        prob (float): probability of applying this transformation.
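+
+    Example (an illustrative pipeline entry; ``mean`` here is only an
+    assumption and should match the ``Normalize`` mean used in the same
+    pipeline, since it fills the expanded canvas):
+
+    .. code-block::
+
+        dict(type='Expand',
+             mean=[123.675, 116.28, 103.53],
+             to_rgb=True,
+             ratio_range=(1, 4))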
+ """
+
+ def __init__(self,
+ mean=(0, 0, 0),
+ to_rgb=True,
+ ratio_range=(1, 4),
+ seg_ignore_label=None,
+ prob=0.5):
+ self.to_rgb = to_rgb
+ self.ratio_range = ratio_range
+ if to_rgb:
+ self.mean = mean[::-1]
+ else:
+ self.mean = mean
+ self.min_ratio, self.max_ratio = ratio_range
+ self.seg_ignore_label = seg_ignore_label
+ self.prob = prob
+
+ def __call__(self, results):
+ """Call function to expand images, bounding boxes.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images, bounding boxes expanded
+ """
+
+ if random.uniform(0, 1) > self.prob:
+ return results
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+
+ h, w, c = img.shape
+ ratio = random.uniform(self.min_ratio, self.max_ratio)
+        # speed up expand when the image is large
+ if np.all(self.mean == self.mean[0]):
+ expand_img = np.empty((int(h * ratio), int(w * ratio), c),
+ img.dtype)
+ expand_img.fill(self.mean[0])
+ else:
+ expand_img = np.full((int(h * ratio), int(w * ratio), c),
+ self.mean,
+ dtype=img.dtype)
+ left = int(random.uniform(0, w * ratio - w))
+ top = int(random.uniform(0, h * ratio - h))
+ expand_img[top:top + h, left:left + w] = img
+
+ results['img'] = expand_img
+ # expand bboxes
+ for key in results.get('bbox_fields', []):
+ results[key] = results[key] + np.tile(
+ (left, top), 2).astype(results[key].dtype)
+
+ # expand masks
+ for key in results.get('mask_fields', []):
+ results[key] = results[key].expand(
+ int(h * ratio), int(w * ratio), top, left)
+
+ # expand segs
+ for key in results.get('seg_fields', []):
+ gt_seg = results[key]
+ expand_gt_seg = np.full((int(h * ratio), int(w * ratio)),
+ self.seg_ignore_label,
+ dtype=gt_seg.dtype)
+ expand_gt_seg[top:top + h, left:left + w] = gt_seg
+ results[key] = expand_gt_seg
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(mean={self.mean}, to_rgb={self.to_rgb}, '
+ repr_str += f'ratio_range={self.ratio_range}, '
+ repr_str += f'seg_ignore_label={self.seg_ignore_label})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class MinIoURandomCrop(object):
+ """Random crop the image & bboxes, the cropped patches have minimum IoU
+ requirement with original image & bboxes, the IoU threshold is randomly
+ selected from min_ious.
+
+ Args:
+ min_ious (tuple): minimum IoU threshold for all intersections with
+ bounding boxes
+ min_crop_size (float): minimum crop's size (i.e. h,w := a*h, a*w,
+ where a >= min_crop_size).
+        bbox_clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
+
+ Note:
+ The keys for bboxes, labels and masks should be paired. That is, \
+ `gt_bboxes` corresponds to `gt_labels` and `gt_masks`, and \
+ `gt_bboxes_ignore` to `gt_labels_ignore` and `gt_masks_ignore`.
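+
+    Example (an illustrative pipeline entry; these are simply the defaults,
+    not values tuned for any particular dataset):
+
+    .. code-block::
+
+        dict(type='MinIoURandomCrop',
+             min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+             min_crop_size=0.3)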
+ """
+
+ def __init__(self,
+ min_ious=(0.1, 0.3, 0.5, 0.7, 0.9),
+ min_crop_size=0.3,
+ bbox_clip_border=True):
+ # 1: return ori img
+ self.min_ious = min_ious
+ self.sample_mode = (1, *min_ious, 0)
+ self.min_crop_size = min_crop_size
+ self.bbox_clip_border = bbox_clip_border
+ self.bbox2label = {
+ 'gt_bboxes': 'gt_labels',
+ 'gt_bboxes_ignore': 'gt_labels_ignore'
+ }
+ self.bbox2mask = {
+ 'gt_bboxes': 'gt_masks',
+ 'gt_bboxes_ignore': 'gt_masks_ignore'
+ }
+
+ def __call__(self, results):
+ """Call function to crop images and bounding boxes with minimum IoU
+ constraint.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images and bounding boxes cropped, \
+ 'img_shape' key is updated.
+ """
+
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ img = results['img']
+ assert 'bbox_fields' in results
+ boxes = [results[key] for key in results['bbox_fields']]
+ boxes = np.concatenate(boxes, 0)
+ h, w, c = img.shape
+ while True:
+ mode = random.choice(self.sample_mode)
+ self.mode = mode
+ if mode == 1:
+ return results
+
+ min_iou = mode
+ for i in range(50):
+ new_w = random.uniform(self.min_crop_size * w, w)
+ new_h = random.uniform(self.min_crop_size * h, h)
+
+ # h / w in [0.5, 2]
+ if new_h / new_w < 0.5 or new_h / new_w > 2:
+ continue
+
+ left = random.uniform(w - new_w)
+ top = random.uniform(h - new_h)
+
+ patch = np.array(
+ (int(left), int(top), int(left + new_w), int(top + new_h)))
+ # Line or point crop is not allowed
+ if patch[2] == patch[0] or patch[3] == patch[1]:
+ continue
+ overlaps = bbox_overlaps(
+ patch.reshape(-1, 4), boxes.reshape(-1, 4)).reshape(-1)
+ if len(overlaps) > 0 and overlaps.min() < min_iou:
+ continue
+
+                # the center of each box should be inside the cropped image;
+                # only adjust boxes and instance masks when the gt is not empty
+ if len(overlaps) > 0:
+ # adjust boxes
+ def is_center_of_bboxes_in_patch(boxes, patch):
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask = ((center[:, 0] > patch[0]) *
+ (center[:, 1] > patch[1]) *
+ (center[:, 0] < patch[2]) *
+ (center[:, 1] < patch[3]))
+ return mask
+
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
+ if not mask.any():
+ continue
+ for key in results.get('bbox_fields', []):
+ boxes = results[key].copy()
+ mask = is_center_of_bboxes_in_patch(boxes, patch)
+ boxes = boxes[mask]
+ if self.bbox_clip_border:
+ boxes[:, 2:] = boxes[:, 2:].clip(max=patch[2:])
+ boxes[:, :2] = boxes[:, :2].clip(min=patch[:2])
+ boxes -= np.tile(patch[:2], 2)
+
+ results[key] = boxes
+ # labels
+ label_key = self.bbox2label.get(key)
+ if label_key in results:
+ results[label_key] = results[label_key][mask]
+
+ # mask fields
+ mask_key = self.bbox2mask.get(key)
+ if mask_key in results:
+ results[mask_key] = results[mask_key][
+ mask.nonzero()[0]].crop(patch)
+ # adjust the img no matter whether the gt is empty before crop
+ img = img[patch[1]:patch[3], patch[0]:patch[2]]
+ results['img'] = img
+ results['img_shape'] = img.shape
+
+ # seg fields
+ for key in results.get('seg_fields', []):
+ results[key] = results[key][patch[1]:patch[3],
+ patch[0]:patch[2]]
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(min_ious={self.min_ious}, '
+ repr_str += f'min_crop_size={self.min_crop_size}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Corrupt(object):
+ """Corruption augmentation.
+
+    Corruption transforms implemented based on
+    `imagecorruptions <https://github.com/bethgelab/imagecorruptions>`_.
+
+ Args:
+ corruption (str): Corruption name.
+ severity (int, optional): The severity of corruption. Default: 1.
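+
+    Example (a sketch; 'gaussian_noise' is one of the corruption names
+    provided by the imagecorruptions package):
+
+    .. code-block::
+
+        dict(type='Corrupt', corruption='gaussian_noise', severity=2)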
+ """
+
+ def __init__(self, corruption, severity=1):
+ self.corruption = corruption
+ self.severity = severity
+
+ def __call__(self, results):
+ """Call function to corrupt image.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Result dict with images corrupted.
+ """
+
+ if corrupt is None:
+ raise RuntimeError('imagecorruptions is not installed')
+ if 'img_fields' in results:
+ assert results['img_fields'] == ['img'], \
+ 'Only single img_fields is allowed'
+ results['img'] = corrupt(
+ results['img'].astype(np.uint8),
+ corruption_name=self.corruption,
+ severity=self.severity)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(corruption={self.corruption}, '
+ repr_str += f'severity={self.severity})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class Albu(object):
+ """Albumentation augmentation.
+
+ Adds custom transformations from Albumentations library.
+    Please visit `https://albumentations.readthedocs.io`
+    for more information.
+
+    An example of ``transforms`` is as follows:
+
+ .. code-block::
+
+ [
+ dict(
+ type='ShiftScaleRotate',
+ shift_limit=0.0625,
+ scale_limit=0.0,
+ rotate_limit=0,
+ interpolation=1,
+ p=0.5),
+ dict(
+ type='RandomBrightnessContrast',
+ brightness_limit=[0.1, 0.3],
+ contrast_limit=[0.1, 0.3],
+ p=0.2),
+ dict(type='ChannelShuffle', p=0.1),
+ dict(
+ type='OneOf',
+ transforms=[
+ dict(type='Blur', blur_limit=3, p=1.0),
+ dict(type='MedianBlur', blur_limit=3, p=1.0)
+ ],
+ p=0.1),
+ ]
+
+ Args:
+        transforms (list[dict]): A list of albu transformations.
+        bbox_params (dict): Bbox_params for albumentation `Compose`.
+        keymap (dict): Contains {'input key':'albumentation-style key'}.
+        update_pad_shape (bool): Whether to update 'pad_shape' from the shape
+            of the augmented image. Defaults to False.
+        skip_img_without_anno (bool): Whether to skip the image if no
+            annotations are left after augmentation.
+ """
+
+ def __init__(self,
+ transforms,
+ bbox_params=None,
+ keymap=None,
+ update_pad_shape=False,
+ skip_img_without_anno=False):
+ if Compose is None:
+ raise RuntimeError('albumentations is not installed')
+
+ # Args will be modified later, copying it will be safer
+ transforms = copy.deepcopy(transforms)
+ if bbox_params is not None:
+ bbox_params = copy.deepcopy(bbox_params)
+ if keymap is not None:
+ keymap = copy.deepcopy(keymap)
+ self.transforms = transforms
+ self.filter_lost_elements = False
+ self.update_pad_shape = update_pad_shape
+ self.skip_img_without_anno = skip_img_without_anno
+
+ # A simple workaround to remove masks without boxes
+ if (isinstance(bbox_params, dict) and 'label_fields' in bbox_params
+ and 'filter_lost_elements' in bbox_params):
+ self.filter_lost_elements = True
+ self.origin_label_fields = bbox_params['label_fields']
+ bbox_params['label_fields'] = ['idx_mapper']
+ del bbox_params['filter_lost_elements']
+
+ self.bbox_params = (
+ self.albu_builder(bbox_params) if bbox_params else None)
+ self.aug = Compose([self.albu_builder(t) for t in self.transforms],
+ bbox_params=self.bbox_params)
+
+ if not keymap:
+ self.keymap_to_albu = {
+ 'img': 'image',
+ 'gt_masks': 'masks',
+ 'gt_bboxes': 'bboxes'
+ }
+ else:
+ self.keymap_to_albu = keymap
+ self.keymap_back = {v: k for k, v in self.keymap_to_albu.items()}
+
+ def albu_builder(self, cfg):
+ """Import a module from albumentations.
+
+ It inherits some of :func:`build_from_cfg` logic.
+
+ Args:
+ cfg (dict): Config dict. It should at least contain the key "type".
+
+ Returns:
+ obj: The constructed object.
+ """
+
+ assert isinstance(cfg, dict) and 'type' in cfg
+ args = cfg.copy()
+
+ obj_type = args.pop('type')
+ if mmcv.is_str(obj_type):
+ if albumentations is None:
+ raise RuntimeError('albumentations is not installed')
+ obj_cls = getattr(albumentations, obj_type)
+ elif inspect.isclass(obj_type):
+ obj_cls = obj_type
+ else:
+ raise TypeError(
+ f'type must be a str or valid type, but got {type(obj_type)}')
+
+ if 'transforms' in args:
+ args['transforms'] = [
+ self.albu_builder(transform)
+ for transform in args['transforms']
+ ]
+
+ return obj_cls(**args)
+
+ @staticmethod
+ def mapper(d, keymap):
+ """Dictionary mapper. Renames keys according to keymap provided.
+
+ Args:
+ d (dict): old dict
+ keymap (dict): {'old_key':'new_key'}
+ Returns:
+ dict: new dict.
+ """
+
+ updated_dict = {}
+        for k, v in d.items():
+            new_k = keymap.get(k, k)
+            updated_dict[new_k] = v
+ return updated_dict
+
+ def __call__(self, results):
+ # dict to albumentations format
+ results = self.mapper(results, self.keymap_to_albu)
+ # TODO: add bbox_fields
+ if 'bboxes' in results:
+ # to list of boxes
+ if isinstance(results['bboxes'], np.ndarray):
+ results['bboxes'] = [x for x in results['bboxes']]
+ # add pseudo-field for filtration
+ if self.filter_lost_elements:
+ results['idx_mapper'] = np.arange(len(results['bboxes']))
+
+ # TODO: Support mask structure in albu
+ if 'masks' in results:
+ if isinstance(results['masks'], PolygonMasks):
+ raise NotImplementedError(
+ 'Albu only supports BitMap masks now')
+ ori_masks = results['masks']
+ if albumentations.__version__ < '0.5':
+ results['masks'] = results['masks'].masks
+ else:
+ results['masks'] = [mask for mask in results['masks'].masks]
+
+ results = self.aug(**results)
+
+ if 'bboxes' in results:
+ if isinstance(results['bboxes'], list):
+ results['bboxes'] = np.array(
+ results['bboxes'], dtype=np.float32)
+ results['bboxes'] = results['bboxes'].reshape(-1, 4)
+
+ # filter label_fields
+ if self.filter_lost_elements:
+
+ for label in self.origin_label_fields:
+ results[label] = np.array(
+ [results[label][i] for i in results['idx_mapper']])
+ if 'masks' in results:
+ results['masks'] = np.array(
+ [results['masks'][i] for i in results['idx_mapper']])
+ results['masks'] = ori_masks.__class__(
+ results['masks'], results['image'].shape[0],
+ results['image'].shape[1])
+
+ if (not len(results['idx_mapper'])
+ and self.skip_img_without_anno):
+ return None
+
+ if 'gt_labels' in results:
+ if isinstance(results['gt_labels'], list):
+ results['gt_labels'] = np.array(results['gt_labels'])
+ results['gt_labels'] = results['gt_labels'].astype(np.int64)
+
+ # back to the original format
+ results = self.mapper(results, self.keymap_back)
+
+ # update final shape
+ if self.update_pad_shape:
+ results['pad_shape'] = results['img'].shape
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__ + f'(transforms={self.transforms})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class RandomCenterCropPad(object):
+ """Random center crop and random around padding for CornerNet.
+
+    This operation generates a randomly cropped image from the original image
+    and pads it simultaneously. Different from :class:`RandomCrop`, the output
+    shape may not strictly equal ``crop_size``. We choose a random value from
+    ``ratios`` and the output shape could be larger or smaller than
+    ``crop_size``. The padding operation is also different from :class:`Pad`;
+    here we use around padding instead of right-bottom padding.
+
+ The relation between output image (padding image) and original image:
+
+ .. code:: text
+
+ output image
+
+ +----------------------------+
+ | padded area |
+ +------|----------------------------|----------+
+ | | cropped area | |
+ | | +---------------+ | |
+ | | | . center | | | original image
+ | | | range | | |
+ | | +---------------+ | |
+ +------|----------------------------|----------+
+ | padded area |
+ +----------------------------+
+
+ There are 5 main areas in the figure:
+
+    - output image: output image of this operation, also called the padding
+      image in the following instructions.
+    - original image: input image of this operation.
+    - padded area: non-intersecting area of the output image and the
+      original image.
+    - cropped area: the overlap of the output image and the original image.
+    - center range: a smaller area from which the random center is chosen.
+      The center range is computed from ``border`` and the original image's
+      shape so that the random center is not too close to the original
+      image's border.
+
+    This operation also acts differently in train and test modes; the summary
+    pipeline for each is listed below.
+
+ Train pipeline:
+
+    1. Choose a ``random_ratio`` from ``ratios``; the shape of the padding
+       image will be ``random_ratio * crop_size``.
+    2. Choose a ``random_center`` in the center range.
+    3. Generate the padding image so that its center matches
+       ``random_center``.
+    4. Initialize the padding image with pixel values equal to ``mean``.
+    5. Copy the cropped area to the padding image.
+    6. Refine annotations.
+
+ Test pipeline:
+
+    1. Compute the output shape according to ``test_pad_mode``.
+    2. Generate the padding image so that its center matches the original
+       image center.
+    3. Initialize the padding image with pixel values equal to ``mean``.
+    4. Copy the ``cropped area`` to the padding image.
+
+ Args:
+        crop_size (tuple | None): expected size after crop; the final size
+            will be computed according to ratio. Requires (h, w) in train
+            mode, and None in test mode.
+        ratios (tuple): randomly select a ratio from the tuple and crop the
+            image to (crop_size[0] * ratio) x (crop_size[1] * ratio).
+            Only available in train mode.
+        border (int): max distance from the center-select area to the image
+            border. Only available in train mode.
+ mean (sequence): Mean values of 3 channels.
+ std (sequence): Std values of 3 channels.
+ to_rgb (bool): Whether to convert the image from BGR to RGB.
+        test_mode (bool): whether to involve random variables in the
+            transform. In train mode, crop_size is fixed, and center coords
+            and ratio are randomly selected from predefined lists. In test
+            mode, crop_size is the image's original shape, and center coords
+            and ratio are fixed.
+ test_pad_mode (tuple): padding method and padding shape value, only
+ available in test mode. Default is using 'logical_or' with
+ 127 as padding shape value.
+
+ - 'logical_or': final_shape = input_shape | padding_shape_value
+ - 'size_divisor': final_shape = int(
+ ceil(input_shape / padding_shape_value) * padding_shape_value)
+        bbox_clip_border (bool, optional): Whether to clip the objects outside
+            the border of the image. Defaults to True.
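+
+    Example (an illustrative train-mode entry; the values are assumptions
+    sketching typical usage, with ``mean``/``std`` expected to match your
+    ``Normalize`` settings):
+
+    .. code-block::
+
+        dict(type='RandomCenterCropPad',
+             crop_size=(511, 511),
+             ratios=(0.9, 1.0, 1.1),
+             border=128,
+             mean=[0, 0, 0],
+             std=[1, 1, 1],
+             to_rgb=True,
+             test_mode=False,
+             test_pad_mode=None)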
+ """
+
+ def __init__(self,
+ crop_size=None,
+ ratios=(0.9, 1.0, 1.1),
+ border=128,
+ mean=None,
+ std=None,
+ to_rgb=None,
+ test_mode=False,
+ test_pad_mode=('logical_or', 127),
+ bbox_clip_border=True):
+ if test_mode:
+ assert crop_size is None, 'crop_size must be None in test mode'
+ assert ratios is None, 'ratios must be None in test mode'
+ assert border is None, 'border must be None in test mode'
+ assert isinstance(test_pad_mode, (list, tuple))
+ assert test_pad_mode[0] in ['logical_or', 'size_divisor']
+ else:
+ assert isinstance(crop_size, (list, tuple))
+ assert crop_size[0] > 0 and crop_size[1] > 0, (
+ 'crop_size must > 0 in train mode')
+ assert isinstance(ratios, (list, tuple))
+ assert test_pad_mode is None, (
+ 'test_pad_mode must be None in train mode')
+
+ self.crop_size = crop_size
+ self.ratios = ratios
+ self.border = border
+ # We do not set default value to mean, std and to_rgb because these
+ # hyper-parameters are easy to forget but could affect the performance.
+ # Please use the same setting as Normalize for performance assurance.
+ assert mean is not None and std is not None and to_rgb is not None
+ self.to_rgb = to_rgb
+ self.input_mean = mean
+ self.input_std = std
+ if to_rgb:
+ self.mean = mean[::-1]
+ self.std = std[::-1]
+ else:
+ self.mean = mean
+ self.std = std
+ self.test_mode = test_mode
+ self.test_pad_mode = test_pad_mode
+ self.bbox_clip_border = bbox_clip_border
+
+ def _get_border(self, border, size):
+ """Get final border for the target size.
+
+        This function generates a ``final_border`` according to the image's
+        shape. The area between ``final_border`` and ``size - final_border``
+        is the ``center range``. We randomly choose the center from this
+        ``center range`` to keep it from being too close to the image border.
+        The ``center range`` should also be larger than 0.
+
+ Args:
+ border (int): The initial border, default is 128.
+ size (int): The width or height of original image.
+ Returns:
+ int: The final border.
+ """
+ k = 2 * border / size
+ i = pow(2, np.ceil(np.log2(np.ceil(k))) + (k == int(k)))
+ return border // i
+
+ def _filter_boxes(self, patch, boxes):
+ """Check whether the center of each box is in the patch.
+
+ Args:
+ patch (list[int]): The cropped area, [left, top, right, bottom].
+ boxes (numpy array, (N x 4)): Ground truth boxes.
+
+ Returns:
+ mask (numpy array, (N,)): Each box is inside or outside the patch.
+ """
+ center = (boxes[:, :2] + boxes[:, 2:]) / 2
+ mask = (center[:, 0] > patch[0]) * (center[:, 1] > patch[1]) * (
+ center[:, 0] < patch[2]) * (
+ center[:, 1] < patch[3])
+ return mask
+
+ def _crop_image_and_paste(self, image, center, size):
+ """Crop image with a given center and size, then paste the cropped
+ image to a blank image with two centers align.
+
+ This function is equivalent to generating a blank image with ``size``
+ as its shape. Then cover it on the original image with two centers (
+ the center of blank image and the random center of original image)
+ aligned. The overlap area is paste from the original image and the
+ outside area is filled with ``mean pixel``.
+
+ Args:
+ image (np array, H x W x C): Original image.
+ center (list[int]): Target crop center coord.
+ size (list[int]): Target crop size. [target_h, target_w]
+
+ Returns:
+ cropped_img (np array, target_h x target_w x C): Cropped image.
+ border (np array, 4): The distance of four border of
+ ``cropped_img`` to the original image area, [top, bottom,
+ left, right]
+ patch (list[int]): The cropped area, [left, top, right, bottom].
+ """
+ center_y, center_x = center
+ target_h, target_w = size
+ img_h, img_w, img_c = image.shape
+
+ x0 = max(0, center_x - target_w // 2)
+ x1 = min(center_x + target_w // 2, img_w)
+ y0 = max(0, center_y - target_h // 2)
+ y1 = min(center_y + target_h // 2, img_h)
+ patch = np.array((int(x0), int(y0), int(x1), int(y1)))
+
+ left, right = center_x - x0, x1 - center_x
+ top, bottom = center_y - y0, y1 - center_y
+
+ cropped_center_y, cropped_center_x = target_h // 2, target_w // 2
+ cropped_img = np.zeros((target_h, target_w, img_c), dtype=image.dtype)
+ for i in range(img_c):
+ cropped_img[:, :, i] += self.mean[i]
+ y_slice = slice(cropped_center_y - top, cropped_center_y + bottom)
+ x_slice = slice(cropped_center_x - left, cropped_center_x + right)
+ cropped_img[y_slice, x_slice, :] = image[y0:y1, x0:x1, :]
+
+ border = np.array([
+ cropped_center_y - top, cropped_center_y + bottom,
+ cropped_center_x - left, cropped_center_x + right
+ ],
+ dtype=np.float32)
+
+ return cropped_img, border, patch
+
+ def _train_aug(self, results):
+ """Random crop and around padding the original image.
+
+ Args:
+            results (dict): Image information in the augmentation pipeline.
+
+ Returns:
+ results (dict): The updated dict.
+ """
+ img = results['img']
+ h, w, c = img.shape
+ boxes = results['gt_bboxes']
+ while True:
+ scale = random.choice(self.ratios)
+ new_h = int(self.crop_size[0] * scale)
+ new_w = int(self.crop_size[1] * scale)
+ h_border = self._get_border(self.border, h)
+ w_border = self._get_border(self.border, w)
+
+ for i in range(50):
+ center_x = random.randint(low=w_border, high=w - w_border)
+ center_y = random.randint(low=h_border, high=h - h_border)
+
+ cropped_img, border, patch = self._crop_image_and_paste(
+ img, [center_y, center_x], [new_h, new_w])
+
+ mask = self._filter_boxes(patch, boxes)
+                # if the image has no valid bbox, any crop patch is valid.
+ if not mask.any() and len(boxes) > 0:
+ continue
+
+ results['img'] = cropped_img
+ results['img_shape'] = cropped_img.shape
+ results['pad_shape'] = cropped_img.shape
+
+ x0, y0, x1, y1 = patch
+
+ left_w, top_h = center_x - x0, center_y - y0
+ cropped_center_x, cropped_center_y = new_w // 2, new_h // 2
+
+ # crop bboxes accordingly and clip to the image boundary
+ for key in results.get('bbox_fields', []):
+ mask = self._filter_boxes(patch, results[key])
+ bboxes = results[key][mask]
+ bboxes[:, 0:4:2] += cropped_center_x - left_w - x0
+ bboxes[:, 1:4:2] += cropped_center_y - top_h - y0
+ if self.bbox_clip_border:
+ bboxes[:, 0:4:2] = np.clip(bboxes[:, 0:4:2], 0, new_w)
+ bboxes[:, 1:4:2] = np.clip(bboxes[:, 1:4:2], 0, new_h)
+ keep = (bboxes[:, 2] > bboxes[:, 0]) & (
+ bboxes[:, 3] > bboxes[:, 1])
+ bboxes = bboxes[keep]
+ results[key] = bboxes
+ if key in ['gt_bboxes']:
+ if 'gt_labels' in results:
+ labels = results['gt_labels'][mask]
+ labels = labels[keep]
+ results['gt_labels'] = labels
+ if 'gt_masks' in results:
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports bbox.')
+
+ # crop semantic seg
+ for key in results.get('seg_fields', []):
+ raise NotImplementedError(
+ 'RandomCenterCropPad only supports bbox.')
+ return results
+
+ def _test_aug(self, results):
+ """Around padding the original image without cropping.
+
+ The padding mode and value are from ``test_pad_mode``.
+
+ Args:
+            results (dict): Image information in the augmentation pipeline.
+
+ Returns:
+ results (dict): The updated dict.
+ """
+ img = results['img']
+ h, w, c = img.shape
+ results['img_shape'] = img.shape
+ if self.test_pad_mode[0] in ['logical_or']:
+ target_h = h | self.test_pad_mode[1]
+ target_w = w | self.test_pad_mode[1]
+ elif self.test_pad_mode[0] in ['size_divisor']:
+ divisor = self.test_pad_mode[1]
+ target_h = int(np.ceil(h / divisor)) * divisor
+ target_w = int(np.ceil(w / divisor)) * divisor
+ else:
+ raise NotImplementedError(
+                'RandomCenterCropPad only supports two testing pad modes: '
+                'logical_or and size_divisor.')
+
+ cropped_img, border, _ = self._crop_image_and_paste(
+ img, [h // 2, w // 2], [target_h, target_w])
+ results['img'] = cropped_img
+ results['pad_shape'] = cropped_img.shape
+ results['border'] = border
+ return results
+
+ def __call__(self, results):
+ img = results['img']
+ assert img.dtype == np.float32, (
+ 'RandomCenterCropPad needs the input image of dtype np.float32,'
+ ' please set "to_float32=True" in "LoadImageFromFile" pipeline')
+ h, w, c = img.shape
+ assert c == len(self.mean)
+ if self.test_mode:
+ return self._test_aug(results)
+ else:
+ return self._train_aug(results)
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(crop_size={self.crop_size}, '
+ repr_str += f'ratios={self.ratios}, '
+ repr_str += f'border={self.border}, '
+ repr_str += f'mean={self.input_mean}, '
+ repr_str += f'std={self.input_std}, '
+ repr_str += f'to_rgb={self.to_rgb}, '
+ repr_str += f'test_mode={self.test_mode}, '
+ repr_str += f'test_pad_mode={self.test_pad_mode}, '
+ repr_str += f'bbox_clip_border={self.bbox_clip_border})'
+ return repr_str
+
+
+@PIPELINES.register_module()
+class CutOut(object):
+ """CutOut operation.
+
+    Randomly drop some regions of the image, as used in
+    `Cutout <https://arxiv.org/abs/1708.04552>`_.
+
+ Args:
+ n_holes (int | tuple[int, int]): Number of regions to be dropped.
+ If it is given as a list, number of holes will be randomly
+ selected from the closed interval [`n_holes[0]`, `n_holes[1]`].
+ cutout_shape (tuple[int, int] | list[tuple[int, int]]): The candidate
+ shape of dropped regions. It can be `tuple[int, int]` to use a
+ fixed cutout shape, or `list[tuple[int, int]]` to randomly choose
+ shape from the list.
+ cutout_ratio (tuple[float, float] | list[tuple[float, float]]): The
+ candidate ratio of dropped regions. It can be `tuple[float, float]`
+ to use a fixed ratio or `list[tuple[float, float]]` to randomly
+ choose ratio from the list. Please note that `cutout_shape`
+ and `cutout_ratio` cannot be both given at the same time.
+ fill_in (tuple[float, float, float] | tuple[int, int, int]): The value
+ of pixel to fill in the dropped regions. Default: (0, 0, 0).
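+
+    Example (an illustrative pipeline entry; the hole count and ratios are
+    arbitrary assumptions):
+
+    .. code-block::
+
+        dict(type='CutOut',
+             n_holes=(1, 3),
+             cutout_ratio=[(0.05, 0.05), (0.1, 0.1)])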
+ """
+
+ def __init__(self,
+ n_holes,
+ cutout_shape=None,
+ cutout_ratio=None,
+ fill_in=(0, 0, 0)):
+
+ assert (cutout_shape is None) ^ (cutout_ratio is None), \
+ 'Either cutout_shape or cutout_ratio should be specified.'
+ assert (isinstance(cutout_shape, (list, tuple))
+ or isinstance(cutout_ratio, (list, tuple)))
+ if isinstance(n_holes, tuple):
+ assert len(n_holes) == 2 and 0 <= n_holes[0] < n_holes[1]
+ else:
+ n_holes = (n_holes, n_holes)
+ self.n_holes = n_holes
+ self.fill_in = fill_in
+ self.with_ratio = cutout_ratio is not None
+ self.candidates = cutout_ratio if self.with_ratio else cutout_shape
+ if not isinstance(self.candidates, list):
+ self.candidates = [self.candidates]
+
+ def __call__(self, results):
+ """Call function to drop some regions of image."""
+ h, w, c = results['img'].shape
+ n_holes = np.random.randint(self.n_holes[0], self.n_holes[1] + 1)
+ for _ in range(n_holes):
+ x1 = np.random.randint(0, w)
+ y1 = np.random.randint(0, h)
+ index = np.random.randint(0, len(self.candidates))
+ if not self.with_ratio:
+ cutout_w, cutout_h = self.candidates[index]
+ else:
+ cutout_w = int(self.candidates[index][0] * w)
+ cutout_h = int(self.candidates[index][1] * h)
+
+ x2 = np.clip(x1 + cutout_w, 0, w)
+ y2 = np.clip(y1 + cutout_h, 0, h)
+ results['img'][y1:y2, x1:x2, :] = self.fill_in
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f'(n_holes={self.n_holes}, '
+ repr_str += (f'cutout_ratio={self.candidates}, ' if self.with_ratio
+ else f'cutout_shape={self.candidates}, ')
+ repr_str += f'fill_in={self.fill_in})'
+ return repr_str
diff --git a/walt/datasets/walt.py b/walt/datasets/walt.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ad458bf4c16ffb336313e9c9169a510034d2dd
--- /dev/null
+++ b/walt/datasets/walt.py
@@ -0,0 +1,875 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import pycocotools
+from mmcv.utils import print_log
+from pycocotools.coco import COCO
+#from pycocotools.cocoeval import COCOeval
+from .cocoeval import COCOeval
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from .builder import DATASETS
+from mmdet.datasets.custom import CustomDataset
+
+import imagesize
+from concurrent.futures import ProcessPoolExecutor
+import multiprocessing as mp
+from copy import deepcopy
+from tqdm import tqdm
+
+@DATASETS.register_module()
+class WaltDataset(CustomDataset):
+
+ CLASSES = ('vehicle', 'occluded_vehicle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+ 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+ 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+ 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+ if not getattr(pycocotools, '__version__', '0') >= '12.0.2':
+ raise AssertionError(
+ 'Incompatible version of pycocotools is installed. '
+ 'Run pip uninstall pycocotools first. Then run pip '
+ 'install mmpycocotools to install open-mmlab forked '
+ 'pycocotools.')
+        import os.path
+        # save_json writes to <ann_file>/ann.json, so build the same path
+        # here instead of relying on a trailing slash in ann_file.
+        ann_json = os.path.join(ann_file, 'ann.json')
+        print(ann_json)
+        if not os.path.exists(ann_json):
+            self.save_json(ann_file)
+
+        self.coco = COCO(ann_json)
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ total_ann_ids = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ ann_ids = self.coco.get_ann_ids(img_ids=[i])
+ total_ann_ids.extend(ann_ids)
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+ return data_infos
+
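+    # NOTE (assumed layout, inferred from the globs and paths used below
+    # rather than any official spec): save_json converts a CWALT camera
+    # folder to COCO format, reading frames from <ann_file>/images/* and
+    # per-instance masks from <ann_file>/<camname>/Segmentation/*.npz
+    # ('mask' array), and writing the result to <ann_file>/ann.json.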
+ def save_json(self, ann_file):
+ import glob
+ import cv2
+ import time
+ data = {}
+
+ data["info"] = {
+ 'url': "https://www.andrew.cmu.edu/user/dnarapur/",
+ 'year': 2018,
+ 'date_created': time.strftime("%a, %d %b %Y %H:%M:%S +0000",
+ time.localtime()),
+ 'description': "This is a dataset for occlusion detection.",
+ 'version': '1.0',
+ 'contributor': 'CMU'}
+ data["categories"] = [{'name': 'car','id': 0,'supercategory': 'car'}]
+ data["licenses"] = [{'id': 1,
+ 'name': "unknown",
+ 'url': "unknown"}]
+ data["images"] = []
+ data["annotations"] = []
+
+
+ self.data_infs = []
+ self.ann_file = ann_file
+
+ count = 0
+ #for img_folder in sorted(glob.glob(ann_file + '/*')):
+ img_folder = ann_file
+ #print(img_folder + '/images/*', glob.glob(img_folder + '/images/*'))
+ for img_name in tqdm(sorted(glob.glob(img_folder + '/images/*')), desc="Converting CWALT to COCO format "):
+ cam_name = img_folder.split('/')[-1]
+ width, height = imagesize.get(img_name)
+ img_name = img_name.split('/')[-1] #.replace('.npz','.png')
+            info = dict(license=3, height=height, width=width,
+                        file_name=img_name,
+                        date_captured=img_name.split('/')[-1].split('.')[0],
+                        id=count, filename=img_name, camname=cam_name)
+ self.data_infs.append(info)
+
+ data["images"].append({'flickr_url': "unknown",
+ 'coco_url': "unknown",
+ #'file_name': cam_name+'/images/' +img_name,
+ 'file_name': 'images/' +img_name,
+ 'id': count,
+ 'license':1,
+ #'has_visible_keypoints':True,
+ 'date_captured': "unknown",
+ 'width': width,
+ 'height': height})
+ count = count+1
+ #if count<2 and count > 30:
+ #if count > 5:
+ # break
+ #break
+
+ obj_id = 0
+ #for img_folder in sorted(glob.glob(ann_file + '/*')):
+ img_folder = ann_file
+ with ProcessPoolExecutor(max_workers=mp.cpu_count()-1) as executor:
+ img_names = glob.glob(img_folder + '/images/*')
+ for ann_in, count in executor.map(self.get_ann_info_local, list(range(0, len(img_names)-1))):
+ #count = img_names.index(img_folder + '/images/'+ann_in['image_name'])
+ #print(ann_in['image_name'], count, img_names[count])
+
+ for loop in range(len(ann_in['bboxes'])):
+ bbox = ann_in['bboxes'][loop]
+ segmentation = ann_in['masks'][loop]
+
+ data["annotations"].append({
+ 'image_id': count,
+ 'category_id': 0,
+ 'iscrowd': 0,
+ 'occ_percentage': ann_in['occ_percentage'][loop],
+ 'id': obj_id,
+ 'area': int(bbox[2]*bbox[3]),
+ 'bbox': [int(bbox[0]), int(bbox[1]), int(bbox[2])-int(bbox[0]),int(bbox[3])-int(bbox[1])],
+ 'segmentation': [segmentation]
+ })
+ obj_id = obj_id + 1
+ '''
+
+
+ coco_kins=COCO('data/parking_real/kins/update_train_2020.json')
+ catIds = [1,2]#coco_kins.getCatIds(catNms=['car']);
+ imgIds = coco_kins.getImgIds(catIds=catIds );
+
+ count = 0
+ count_obj = 0
+ for id_1 in imgIds:
+ img = coco_kins.loadImgs(id_1)[0]
+
+ data["images"].append({'flickr_url': "unknown",
+ 'coco_url': "unknown",
+ #'file_name': cam_name+'/images/' +img_name,
+ 'file_name': '../kins/'+img['file_name'],
+ 'id': 1000000+count,
+ 'license':1,
+ #'has_visible_keypoints':True,
+ 'date_captured': "unknown",
+ 'width': img['width'],
+ 'height': img['height']})
+ annIds = coco_kins.getAnnIds(imgIds=id_1, catIds=catIds, iscrowd=None)
+ for id_2 in annIds:
+ ann = coco_kins.loadAnns(id_2)
+ data["annotations"].append({
+ 'image_id': 1000000+count,
+ 'category_id': 0,
+ 'iscrowd': 0,
+ 'occ_percentage': ann[0]['i_area']/ann[0]['a_area']*100,
+ 'id': 1000000+count_obj,
+ 'area': ann[0]['a_area'],
+ 'bbox': ann[0]['a_bbox'],
+ 'segmentation': [{'full':ann[0]['a_segm'],'visible':ann[0]['i_segm']}]
+ })
+ count_obj = count_obj+1
+ count= count+1
+ '''
+ '''
+
+ for img_folder in sorted(glob.glob(ann_file + '/*')):
+ for img_name in sorted(glob.glob(img_folder + '/images/*')):
+ #for img_folder in sorted(glob.glob(ann_file.replace('GT_data','images') + '/*')):
+ # for i in sorted(glob.glob(ann_file + '*')):
+ ann_in = self.get_ann_info_local(count)
+ for loop in range(len(ann_in['bboxes'])):
+ bbox = ann_in['bboxes'][loop]
+ segmentation = ann_in['masks'][loop]
+
+ data["annotations"].append({
+ 'image_id': count,
+ 'category_id': 0,
+ 'iscrowd': 0,
+ 'id': obj_id,
+ 'area': int(bbox[2]*bbox[3]),
+ 'bbox': [int(bbox[0]), int(bbox[1]), int(bbox[2])-int(bbox[0]),int(bbox[3])-int(bbox[1])],
+ 'segmentation': [segmentation]
+ })
+ obj_id = obj_id + 1
+ count = count+1
+ #if count<2 and count > 30:
+ #if count > 5:
+ # break
+ #break
+ '''
+ import json
+ json_str = json.dumps(data)
+ with open(ann_file + '/ann.json', 'w') as f:
+ f.write(json_str)
+
+ def get_ann_info_local(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+            tuple: (ann, idx), the annotation info of the specified index.
+ """
+ return self._parse_ann_info_local(idx)
+
+ def _parse_ann_info_local(self, idx):
+ """Parse bbox and mask annotation.
+
+        Args:
+            idx (int): Index of data.
+
+        Returns:
+            tuple: (ann, idx), where ann is a dict containing the following
+                keys: bboxes, bboxes_ignore, labels, masks, occ_percentage,
+                seg_map, image_name. "masks" are raw annotations and not
+                decoded into binary masks.
+ """
+ try:
+ img_info = self.data_infs[idx]
+        except IndexError:
+ img_info = self.data_infs[0]
+
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+ gt_occ_percentage = []
+ import cv2
+ print(self.ann_file + '/Segmentation/' + img_info['filename'])
+
+ #seg_o = cv2.imread(self.ann_file + 'Segmentation' + img_info['filename'])
+
+
+ try:
+ seg_all = np.load(self.ann_file +img_info['camname']+ '/Segmentation/' + img_info['filename'].replace('jpg','npz'))
+ print(seg_all['mask'].shape)
+ for loop in range(seg_all['mask'].shape[0]):
+ seg_o = seg_all['mask'][loop]
+ segmentations_original, encoded_ground_truth_original, ground_truth_binary_mask_original = self.get_segmentation(seg_o, 1)
+                seg_o[seg_o > 0] = 1
+ segmentations, encoded_ground_truth, ground_truth_binary_mask = self.get_segmentation(seg_o, 1)
+
+ x1, y1, w, h = pycocotools.mask.toBbox(encoded_ground_truth)
+ x1_o, y1_o, w_o, h_o = pycocotools.mask.toBbox(encoded_ground_truth_original)
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ bbox = [x1, y1, x1 + w, y1 + h]
+ bbox_o = [x1_o, y1_o, x1_o + w_o, y1_o + h_o]
+ if len(segmentations_original) == 0:
+ continue
+ if w != w_o or h != h_o or len(np.unique(ground_truth_binary_mask-ground_truth_binary_mask_original)) >1:
+ #gt_masks_ann.append([segmentations_original, segmentations])
+ gt_masks_ann.append({'visible': segmentations_original,'full': segmentations})
+ gt_bboxes.append(bbox)
+ gt_labels.append(0)
+ gt_occ_percentage.append(100 - np.sum(ground_truth_binary_mask_original)/np.sum(ground_truth_binary_mask)*100)
+
+ else:
+ gt_masks_ann.append({'visible': segmentations,'full': segmentations})
+ gt_bboxes.append(bbox)
+ gt_labels.append(0)
+ gt_occ_percentage.append(0)
+
+ if inter_w * inter_h == 0:
+ continue
+
+
+
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+ seg_map = img_info['filename']
+        except Exception:
+            print('annotations failed to load for', img_info['filename'])
+        if len(gt_bboxes) == 0:
+            # fall back to the next index when this image has no usable boxes
+            ann = self._parse_ann_info_local(idx + 1)
+            print('annotations failed to load for', img_info['filename'])
+            return ann
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ occ_percentage=gt_occ_percentage,
+ seg_map=seg_map,
+ image_name=img_info['filename'])
+
+ return ann, idx
+
+ def get_segmentation(self, seg, idx):
+ ground_truth_binary_mask = seg.copy()*0
+ ground_truth_binary_mask[seg==idx] = 255
+ ground_truth_binary_mask = ground_truth_binary_mask[:,:,0]
+ fortran_ground_truth_binary_mask = np.asfortranarray(ground_truth_binary_mask)
+ encoded_ground_truth = pycocotools.mask.encode(fortran_ground_truth_binary_mask)
+ ground_truth_area = pycocotools.mask.area(encoded_ground_truth)
+ from skimage import measure
+ contours = measure.find_contours(ground_truth_binary_mask, 0.5)
+ segmentations = []
+ for contour in contours:
+ contour = np.flip(contour, axis=1)
+ segmentation = contour.ravel().tolist()
+ segmentations.append(segmentation)
+ return segmentations, encoded_ground_truth, ground_truth_binary_mask
+
+
+
+ def get_ann_info(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+ def get_cat_ids(self, idx):
+ """Get COCO category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return [ann['category_id'] for ann in ann_info]
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ # obtain images that contain annotation
+ ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.coco.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = self.img_ids[i]
+ if self.filter_empty_gt and img_id not in ids_in_cat:
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+ Args:
+ ann_info (list[dict]): Annotation info of an image.
+ with_mask (bool): Whether to parse mask annotations.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,\
+ labels, masks, seg_map. "masks" are raw annotations and not \
+ decoded into binary masks.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+ for i, ann in enumerate(ann_info):
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ if inter_w * inter_h == 0:
+ continue
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ #bbox = [x1, y1, w, h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ #gt_masks_ann.append(ann.get('segmentation', None))
+ #print(ann.get('segmentation', None)[0].keys())
+            try:
+                segm = ann.get('segmentation', None)[0]
+                gt_masks_ann.append(
+                    {'visible': segm['visible'], 'full': segm['full']})
+            except (KeyError, TypeError):
+                gt_masks_ann.append(
+                    {'visible': ann.get('segmentation', None)[0]['visible']})
+
+
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ seg_map = img_info['filename'].replace('jpg', 'png')
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def xyxy2xywh(self, bbox):
+ """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+ evaluation.
+
+ Args:
+ bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+ ``xyxy`` order.
+
+ Returns:
+ list[float]: The converted bounding boxes, in ``xywh`` order.
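+
+        Example (a quick sketch)::
+
+            >>> # xyxy [10, 20, 30, 60] -> xywh [10, 20, 20, 40]
+            >>> self.xyxy2xywh(np.array([10., 20., 30., 60.]))
+            [10.0, 20.0, 20.0, 40.0]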
+ """
+
+ _bbox = bbox.tolist()
+ return [
+ _bbox[0],
+ _bbox[1],
+ _bbox[2] - _bbox[0],
+ _bbox[3] - _bbox[1],
+ ]
+
+ def _proposal2json(self, results):
+ """Convert proposal results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ bboxes = results[idx]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = 1
+ json_results.append(data)
+ return json_results
+
+ def _det2json(self, results):
+ """Convert detection results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ result = results[idx]
+ for label in range(len(result)):
+ bboxes = result[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ json_results.append(data)
+ return json_results
+
+ def _segm2json(self, results):
+ """Convert instance segmentation results to COCO json style."""
+ bbox_json_results = []
+ segm_json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ det, seg = results[idx]
+ for label in range(len(det)):
+ # bbox results
+ bboxes = det[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ bbox_json_results.append(data)
+
+ # segm results
+ # some detectors use different scores for bbox and mask
+ if isinstance(seg, tuple):
+ segms = seg[0][label]
+ mask_score = seg[1][label]
+ else:
+ segms = seg[label]
+ mask_score = [bbox[4] for bbox in bboxes]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(mask_score[i])
+ data['category_id'] = self.cat_ids[label]
+ if isinstance(segms[i]['counts'], bytes):
+ segms[i]['counts'] = segms[i]['counts'].decode()
+ data['segmentation'] = segms[i]
+ segm_json_results.append(data)
+ return bbox_json_results, segm_json_results
+
+ def results2json(self, results, outfile_prefix):
+ """Dump the detection results to a COCO style json file.
+
+ There are 3 types of results: proposals, bbox predictions, mask
+ predictions, and they have different data types. This method will
+ automatically recognize the type, and dump them to json files.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+ "somepath/xxx.proposal.json".
+
+ Returns:
+ dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
+ values are corresponding filenames.
+ """
+ result_files = dict()
+ if isinstance(results[0], list):
+ json_results = self._det2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ mmcv.dump(json_results, result_files['bbox'])
+ elif isinstance(results[0], tuple):
+ json_results = self._segm2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
+ mmcv.dump(json_results[0], result_files['bbox'])
+ mmcv.dump(json_results[1], result_files['segm'])
+ elif isinstance(results[0], np.ndarray):
+ json_results = self._proposal2json(results)
+ result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+ mmcv.dump(json_results, result_files['proposal'])
+ else:
+ raise TypeError('invalid type of results')
+ return result_files
+
+ def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+ gt_bboxes = []
+ for i in range(len(self.img_ids)):
+ ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+ ann_info = self.coco.load_anns(ann_ids)
+ if len(ann_info) == 0:
+ gt_bboxes.append(np.zeros((0, 4)))
+ continue
+ bboxes = []
+ for ann in ann_info:
+ if ann.get('ignore', False) or ann['iscrowd']:
+ continue
+ x1, y1, w, h = ann['bbox']
+ bboxes.append([x1, y1, x1 + w, y1 + h])
+ #bboxes.append([x1, y1, x1, y1])
+ bboxes = np.array(bboxes, dtype=np.float32)
+ if bboxes.shape[0] == 0:
+ bboxes = np.zeros((0, 4))
+ gt_bboxes.append(bboxes)
+
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+ ar = recalls.mean(axis=1)
+ return ar
+
+ def format_results(self, results, jsonfile_prefix=None, **kwargs):
+ """Format the results to json (standard format for COCO evaluation).
+
+ Args:
+ results (list[tuple | numpy.ndarray]): Testing results of the
+ dataset.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a dict containing \
+                the json filepaths, tmp_dir is the temporary directory created \
+ for saving json files when jsonfile_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ #jsonfile_prefix = osp.join('./', 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=None,
+ metric_items=None):
+ """Evaluation in COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+            classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float], optional): IoU threshold used for
+ evaluating recalls/mAPs. If set to a list, the average of all
+ IoUs will also be computed. If not specified, [0.50, 0.55,
+ 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+ Default: None.
+ metric_items (list[str] | str, optional): Metric items that will
+ be returned. If not specified, ``['AR@100', 'AR@300',
+ 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
+ used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+ 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+ ``metric=='bbox' or metric=='segm'``.
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric.
+ """
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ if iou_thrs is None:
+ iou_thrs = np.linspace(
+ .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
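+ # np.linspace(0.5, 0.95, 10): the ten standard COCO IoU thresholds 0.50:0.05:0.95.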
+ if metric_items is not None:
+ if not isinstance(metric_items, list):
+ metric_items = [metric_items]
+
+ result_files_all, tmp_dir = self.format_results(results, jsonfile_prefix)
+
+ eval_results = OrderedDict()
+ cocoGt_all = self.coco
+ results_all = []
+
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ if metric not in result_files_all:
+ raise KeyError(f'{metric} is not in results')
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ # Evaluate once per occlusion level; percentage_occ is consumed by the
+ # project's customised COCOeval (-1 to 9, the same sweep as before).
+ # COCOeval mutates its inputs, so fresh copies are made every iteration.
+ for occ_level in range(-1, 10):
+ cocoGt = deepcopy(cocoGt_all)
+ result_files = deepcopy(result_files_all)
+ cocoDt = cocoGt.loadRes(result_files[metric])
+ # Ground truth stores amodal masks as [{'visible': ..., 'full': ...}];
+ # COCOeval expects a plain polygon list, so swap in the full masks.
+ for ann_id in cocoGt.anns:
+ cocoGt.anns[ann_id]['segmentation'] = cocoGt.anns[ann_id]['segmentation'][0]['full']
+ cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+ cocoEval.percentage_occ = occ_level
+ cocoEval.params.useCats = 0
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ stats_str = ' '.join('{:.2f}'.format(x) for x in cocoEval.stats)
+ results_all.append(stats_str + ' ' + str(metric) + ' ' + str(occ_level))
+ np.savetxt('results.out', results_all, delimiter=',', fmt="%s")
+
+ cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+ cocoEval.params.catIds = self.cat_ids
+ cocoEval.params.imgIds = self.img_ids
+ cocoEval.params.maxDets = list(proposal_nums)
+ cocoEval.params.iouThrs = iou_thrs
+ # mapping of cocoEval.stats
+ coco_metric_names = {
+ 'mAP': 0,
+ 'mAP_50': 1,
+ 'mAP_75': 2,
+ 'mAP_s': 3,
+ 'mAP_m': 4,
+ 'mAP_l': 5,
+ 'AR@100': 6,
+ 'AR@300': 7,
+ 'AR@1000': 8,
+ 'AR_s@1000': 9,
+ 'AR_m@1000': 10,
+ 'AR_l@1000': 11
+ }
+ if metric_items is not None:
+ for metric_item in metric_items:
+ if metric_item not in coco_metric_names:
+ raise KeyError(
+ f'metric item {metric_item} is not supported')
+ if metric == 'proposal':
+ cocoEval.params.useCats = 0
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if metric_items is None:
+ metric_items = [
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+ 'AR_m@1000', 'AR_l@1000'
+ ]
+
+ for item in metric_items:
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
+ eval_results[item] = val
+ else:
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if classwise: # Compute per-category AP
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = cocoEval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self.coco.loadCats(catId)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ if metric_items is None:
+ metric_items = [
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+
+ for metric_item in metric_items:
+ key = f'{metric}_{metric_item}'
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+ )
+ eval_results[key] = val
+ ap = cocoEval.stats[:6]
+ eval_results[f'{metric}_mAP_copypaste'] = (
+ f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+ f'{ap[4]:.3f} {ap[5]:.3f}')
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/walt/datasets/walt_3d.py b/walt/datasets/walt_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..140afe8b9e3b69a5254e6ea4721aed546f83ddc7
--- /dev/null
+++ b/walt/datasets/walt_3d.py
@@ -0,0 +1,535 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import pycocotools
+from mmcv.utils import print_log
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from .builder import DATASETS
+from .custom import CustomDatasetLocal
+
+
+def bounding_box(points):
+ """returns a list containing the bottom left and the top right
+ points in the sequence
+ Here, we traverse the collection of points only once,
+ to find the min and max for x and y
+ """
+ bot_left_x, bot_left_y = float('inf'), float('inf')
+ top_right_x, top_right_y = float('-inf'), float('-inf')
+ for point in points:
+ x = point[0]
+ y = point[1]
+ if x < 0 or y < 0:
+ continue
+ bot_left_x = min(bot_left_x, x)
+ bot_left_y = min(bot_left_y, y)
+ top_right_x = max(top_right_x, x)
+ top_right_y = max(top_right_y, y)
+
+ return [bot_left_x, bot_left_y, top_right_x, top_right_y]
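+
+ # Example (hypothetical values): the (-100, -100) marker used below for
+ # off-image points is skipped by the negative-coordinate guard:
+ # bounding_box([[10, 20], [-100, -100], [30, 5]]) -> [10, 5, 30, 20]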
+
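+ # Corner-index pairs forming the 12 edges of a 3D box wireframe (corners 0-3:
+ # bottom face; 4-7: the same corners lifted to the top face).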
+lines = [[0,1],[1,3],[0,2],[3,2],[0,4],[1,5],[2,6],[3,7],[4,5],[5,7],[4,6],[7,6]]
+
+def get_boundingbox2d3d(cameraname, gt_data, extrinsics_path):
+ # Locate the extrinsics entry for this camera in the calibration file.
+ with open(extrinsics_path, 'r') as f:
+ for line in f:
+ if cameraname in line.split('\n')[0].split(' ')[0]:
+ a = line.split('\n')[0].split(' ')
+ break
+
+ K = np.reshape(np.array(a[1:10]),[3,3]).astype(float)
+ R = np.reshape(a[10:19], [3,3])
+ T = np.array([[a[19]],[a[20]],[a[21]]])
+ RT = np.hstack((R,T)).astype(float)
+ KRT = np.matmul(K, RT)  # full 3x4 projection matrix (kept for reference)
+ bb_3d_all = []
+ bb_2d_all = []
+ bb_3d_proj_all = []
+
+ for indice, keypoints_3d in enumerate(gt_data['arr_0'][1]):
+ parking_space = gt_data['arr_0'][0][indice][0]
+
+ if gt_data['arr_0'][0][indice][1] == 0:
+ continue
+ points2d_all = []
+ parking_space = np.vstack([parking_space, parking_space+[0,0,2]])
+ parking_space_transformed = []
+ for point in parking_space:
+ point = [point[0], point[1], point[2], 1]
+ point = np.matmul(RT, point)
+ parking_space_transformed.append(list(point))
+ point2d = np.matmul(K, point)
+ # Points behind the camera or outside the 2048x2048 frame are marked
+ # invalid with a (-100, -100) sentinel.
+ if point2d[2] < 0:
+ points2d_all.append([-100, -100, 1])
+ continue
+ point2d = point2d / point2d[2]
+ if point2d[0] < 0 or point2d[0] > 2048:
+ points2d_all.append([-100, -100, 1])
+ continue
+ if point2d[1] < 0 or point2d[1] > 2048:
+ points2d_all.append([-100, -100, 1])
+ continue
+
+ points2d_all.append(point2d)
+
+ bb_3d_proj_all.append(points2d_all)
+ bbox = bounding_box(points2d_all)
+ if float('inf') in bbox:
+ continue
+ bb_2d_all.append(bbox)
+ bb_3d_all.append(parking_space)
+ return bb_3d_all, bb_2d_all, bb_3d_proj_all
+
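+ # A minimal sketch of the projection above, under assumed values
+ # (hypothetical intrinsics K, identity rotation, zero translation):
+ # K = np.array([[1000., 0., 1024.], [0., 1000., 1024.], [0., 0., 1.]])
+ # RT = np.hstack((np.eye(3), np.zeros((3, 1))))
+ # p = np.matmul(K, np.matmul(RT, [2.0, 1.0, 10.0, 1.0]))
+ # p = p / p[2]  # -> array([1224., 1124., 1.])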
+
+@DATASETS.register_module()
+class Walt3DDataset(CustomDatasetLocal):
+
+ CLASSES = ('person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+ 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+ 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+ 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
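+ # The standard 80 COCO class names; this dataset labels every instance as
+ # 'car' (class index 2) in get_ann_info below.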
+
+ def load_annotations(self, ann_file):
+ import glob
+ count = 0
+ data_infos = []
+ self.data_annotations = []
+ for i in glob.glob(ann_file + '*'):
+ gt_data = np.load(i, allow_pickle=True)
+ for img_folder in glob.glob(ann_file.replace('GT_data', 'images') + '/*'):
+ cam_name = img_folder.split('/')[-1]
+ img_name = i.split('/')[-1].replace('.npz', '.png')
+ info = dict(license=3, height=2048, width=2048,
+ file_name=cam_name + '/' + img_name,
+ date_captured=i.split('/')[-1].split('.')[0],
+ id=count, filename=cam_name + '/' + img_name)
+ count = count + 1
+ data_infos.append(info)
+ bb_3d_all, bb_2d_all, bb_3d_proj_all = get_boundingbox2d3d(
+ cam_name, gt_data,
+ ann_file.replace('GT_data', 'Extrinsics') + '/frame_par.txt')
+ self.data_annotations.append([bb_3d_all, bb_2d_all, bb_3d_proj_all])
+ # Only the first camera folder is used for each ground-truth file.
+ break
+ return data_infos
+
+
+ def get_ann_info(self, idx):
+ data = self.data_annotations[idx]
+ gt_bboxes = np.array(data[1])
+ gt_bboxes_3d = np.array(data[0])
+ gt_bboxes_3d_proj = np.array(data[2])
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ bboxes_3d=gt_bboxes_3d,
+ bboxes_3d_proj=gt_bboxes_3d_proj,
+ # Every box is labelled as class index 2 ('car' in CLASSES).
+ labels=np.full(len(gt_bboxes), 2, dtype=int),
+ bboxes_ignore=np.zeros((0, 4), dtype=np.float32),
+ seg_map=np.array([]))
+ return ann
+
+ def get_cat_ids(self, idx):
+ data = self.data_annotations[idx]
+ gt_bboxes = np.array(data[1])
+ return np.full(len(gt_bboxes), 2, dtype=int)
+
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ for data_info in self.data_infos:
+ valid_inds.append(data_info['id'])
+ print(valid_inds)
+
+ return valid_inds
+
+
+ def xyxy2xywh(self, bbox):
+ """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+ evaluation.
+
+ Args:
+ bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+ ``xyxy`` order.
+
+ Returns:
+ list[float]: The converted bounding boxes, in ``xywh`` order.
+ """
+
+ _bbox = bbox.tolist()
+ return [
+ _bbox[0],
+ _bbox[1],
+ _bbox[2] - _bbox[0],
+ _bbox[3] - _bbox[1],
+ ]
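+
+ # e.g. self.xyxy2xywh(np.array([10., 20., 50., 80.])) -> [10.0, 20.0, 40.0, 60.0]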
+
+ def _proposal2json(self, results):
+ """Convert proposal results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ bboxes = results[idx]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = 1
+ json_results.append(data)
+ return json_results
+
+ def _det2json(self, results):
+ """Convert detection results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ result = results[idx]
+ for label in range(len(result)):
+ bboxes = result[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ json_results.append(data)
+ return json_results
+
+ def _segm2json(self, results):
+ """Convert instance segmentation results to COCO json style."""
+ bbox_json_results = []
+ segm_json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ det, seg = results[idx]
+ for label in range(len(det)):
+ # bbox results
+ bboxes = det[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ bbox_json_results.append(data)
+
+ # segm results
+ # some detectors use different scores for bbox and mask
+ if isinstance(seg, tuple):
+ segms = seg[0][label]
+ mask_score = seg[1][label]
+ else:
+ segms = seg[label]
+ mask_score = [bbox[4] for bbox in bboxes]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(mask_score[i])
+ data['category_id'] = self.cat_ids[label]
+ if isinstance(segms[i]['counts'], bytes):
+ segms[i]['counts'] = segms[i]['counts'].decode()
+ data['segmentation'] = segms[i]
+ segm_json_results.append(data)
+ return bbox_json_results, segm_json_results
+
+ def results2json(self, results, outfile_prefix):
+ """Dump the detection results to a COCO style json file.
+
+ There are 3 types of results: proposals, bbox predictions, mask
+ predictions, and they have different data types. This method will
+ automatically recognize the type, and dump them to json files.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+ "somepath/xxx.proposal.json".
+
+ Returns:
+ dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
+ values are corresponding filenames.
+ """
+ result_files = dict()
+ if isinstance(results[0], list):
+ json_results = self._det2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ mmcv.dump(json_results, result_files['bbox'])
+ elif isinstance(results[0], tuple):
+ json_results = self._segm2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
+ mmcv.dump(json_results[0], result_files['bbox'])
+ mmcv.dump(json_results[1], result_files['segm'])
+ elif isinstance(results[0], np.ndarray):
+ json_results = self._proposal2json(results)
+ result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+ mmcv.dump(json_results, result_files['proposal'])
+ else:
+ raise TypeError('invalid type of results')
+ return result_files
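+
+ # Each dumped record follows the COCO results schema, roughly:
+ # {"image_id": 0, "category_id": 0, "bbox": [x, y, w, h], "score": 0.9}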
+
+ def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+ gt_bboxes = []
+ for i in range(len(self.img_ids)):
+ ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+ ann_info = self.coco.load_anns(ann_ids)
+ if len(ann_info) == 0:
+ gt_bboxes.append(np.zeros((0, 4)))
+ continue
+ bboxes = []
+ for ann in ann_info:
+ if ann.get('ignore', False) or ann['iscrowd']:
+ continue
+ x1, y1, w, h = ann['bbox']
+ bboxes.append([x1, y1, x1 + w, y1 + h])
+ bboxes = np.array(bboxes, dtype=np.float32)
+ if bboxes.shape[0] == 0:
+ bboxes = np.zeros((0, 4))
+ gt_bboxes.append(bboxes)
+
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+ ar = recalls.mean(axis=1)
+ return ar
+
+ def format_results(self, results, jsonfile_prefix=None, **kwargs):
+ """Format the results to json (standard format for COCO evaluation).
+
+ Args:
+ results (list[tuple | numpy.ndarray]): Testing results of the
+ dataset.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a dict containing \
+ the json filepaths, tmp_dir is the temporary directory created \
+ for saving json files when jsonfile_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=None,
+ metric_items=None):
+ """Evaluation in COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+ classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float], optional): IoU threshold used for
+ evaluating recalls/mAPs. If set to a list, the average of all
+ IoUs will also be computed. If not specified, [0.50, 0.55,
+ 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+ Default: None.
+ metric_items (list[str] | str, optional): Metric items that will
+ be returned. If not specified, ``['AR@100', 'AR@300',
+ 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
+ used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+ 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+ ``metric=='bbox' or metric=='segm'``.
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric.
+ """
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ if iou_thrs is None:
+ iou_thrs = np.linspace(
+ .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+ if metric_items is not None:
+ if not isinstance(metric_items, list):
+ metric_items = [metric_items]
+
+ result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+
+ eval_results = OrderedDict()
+ cocoGt = self.coco
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ if metric not in result_files:
+ raise KeyError(f'{metric} is not in results')
+ try:
+ cocoDt = cocoGt.loadRes(result_files[metric])
+ except IndexError:
+ print_log(
+ 'The testing results of the whole dataset is empty.',
+ logger=logger,
+ level=logging.ERROR)
+ break
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+ cocoEval.params.catIds = self.cat_ids
+ cocoEval.params.imgIds = self.img_ids
+ cocoEval.params.maxDets = list(proposal_nums)
+ cocoEval.params.iouThrs = iou_thrs
+ # mapping of cocoEval.stats
+ coco_metric_names = {
+ 'mAP': 0,
+ 'mAP_50': 1,
+ 'mAP_75': 2,
+ 'mAP_s': 3,
+ 'mAP_m': 4,
+ 'mAP_l': 5,
+ 'AR@100': 6,
+ 'AR@300': 7,
+ 'AR@1000': 8,
+ 'AR_s@1000': 9,
+ 'AR_m@1000': 10,
+ 'AR_l@1000': 11
+ }
+ if metric_items is not None:
+ for metric_item in metric_items:
+ if metric_item not in coco_metric_names:
+ raise KeyError(
+ f'metric item {metric_item} is not supported')
+
+ if metric == 'proposal':
+ cocoEval.params.useCats = 0
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if metric_items is None:
+ metric_items = [
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+ 'AR_m@1000', 'AR_l@1000'
+ ]
+
+ for item in metric_items:
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
+ eval_results[item] = val
+ else:
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if classwise: # Compute per-category AP
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = cocoEval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self.coco.loadCats(catId)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ if metric_items is None:
+ metric_items = [
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+
+ for metric_item in metric_items:
+ key = f'{metric}_{metric_item}'
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+ )
+ eval_results[key] = val
+ ap = cocoEval.stats[:6]
+ eval_results[f'{metric}_mAP_copypaste'] = (
+ f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+ f'{ap[4]:.3f} {ap[5]:.3f}')
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/walt/datasets/walt_synthetic.py b/walt/datasets/walt_synthetic.py
new file mode 100644
index 0000000000000000000000000000000000000000..df658be8e1c49b31f8d21ffae8416ce1f0ad650e
--- /dev/null
+++ b/walt/datasets/walt_synthetic.py
@@ -0,0 +1,781 @@
+import itertools
+import logging
+import os.path as osp
+import tempfile
+from collections import OrderedDict
+
+import mmcv
+import numpy as np
+import pycocotools
+from mmcv.utils import print_log
+from pycocotools.coco import COCO
+from pycocotools.cocoeval import COCOeval
+from terminaltables import AsciiTable
+
+from mmdet.core import eval_recalls
+from .builder import DATASETS
+from mmdet.datasets.custom import CustomDataset
+
+
+@DATASETS.register_module()
+class WaltSynthDataset(CustomDataset):
+
+ CLASSES = ('vehicle', 'occluded_vehicle', 'car', 'motorcycle', 'airplane', 'bus',
+ 'train', 'truck', 'boat', 'traffic light', 'fire hydrant',
+ 'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog',
+ 'horse', 'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe',
+ 'backpack', 'umbrella', 'handbag', 'tie', 'suitcase', 'frisbee',
+ 'skis', 'snowboard', 'sports ball', 'kite', 'baseball bat',
+ 'baseball glove', 'skateboard', 'surfboard', 'tennis racket',
+ 'bottle', 'wine glass', 'cup', 'fork', 'knife', 'spoon', 'bowl',
+ 'banana', 'apple', 'sandwich', 'orange', 'broccoli', 'carrot',
+ 'hot dog', 'pizza', 'donut', 'cake', 'chair', 'couch',
+ 'potted plant', 'bed', 'dining table', 'toilet', 'tv', 'laptop',
+ 'mouse', 'remote', 'keyboard', 'cell phone', 'microwave',
+ 'oven', 'toaster', 'sink', 'refrigerator', 'book', 'clock',
+ 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush')
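+
+ # COCO-style class list with 'vehicle' and 'occluded_vehicle' substituted for
+ # the first two entries; the generated ann.json only defines a 'car' category.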
+
+ def load_annotations(self, ann_file):
+ """Load annotation from COCO style annotation file.
+
+ Args:
+ ann_file (str): Path of annotation file.
+
+ Returns:
+ list[dict]: Annotation info from COCO api.
+ """
+ if not getattr(pycocotools, '__version__', '0') >= '12.0.2':
+ raise AssertionError(
+ 'Incompatible version of pycocotools is installed. '
+ 'Run pip uninstall pycocotools first. Then run pip '
+ 'install mmpycocotools to install open-mmlab forked '
+ 'pycocotools.')
+ if not osp.exists(osp.join(ann_file, 'ann.json')):
+ self.save_json(ann_file)
+
+ self.coco = COCO(osp.join(ann_file, 'ann.json'))
+ self.cat_ids = self.coco.get_cat_ids(cat_names=self.CLASSES)
+ self.cat2label = {cat_id: i for i, cat_id in enumerate(self.cat_ids)}
+ self.img_ids = self.coco.get_img_ids()
+ data_infos = []
+ total_ann_ids = []
+ for i in self.img_ids:
+ info = self.coco.load_imgs([i])[0]
+ info['filename'] = info['file_name']
+ data_infos.append(info)
+ ann_ids = self.coco.get_ann_ids(img_ids=[i])
+ total_ann_ids.extend(ann_ids)
+ assert len(set(total_ann_ids)) == len(
+ total_ann_ids), f"Annotation ids in '{ann_file}' are not unique!"
+ return data_infos
+
+
+ def save_json(self, ann_file):
+ import glob
+ import json
+ import time
+ data = {}
+
+ data["info"] = {
+ 'url': "https://www.andrew.cmu.edu/user/dnarapur/",
+ 'year': 2018,
+ 'date_created': time.strftime("%a, %d %b %Y %H:%M:%S +0000",
+ time.localtime()),
+ 'description': "This is a dataset for occlusion detection.",
+ 'version': '1.0',
+ 'contributor': 'CMU'}
+ data["categories"] = [{'name': 'car','id': 0,'supercategory': 'car'}]
+ data["licenses"] = [{'id': 1,
+ 'name': "unknown",
+ 'url': "unknown"}]
+ data["images"] = []
+ data["annotations"] = []
+
+
+ self.data_infs = []
+ self.ann_file = ann_file
+
+ count = 0
+ for img_folder in sorted(glob.glob(ann_file.replace('GT_data','images') + '/*')):
+ for i in sorted(glob.glob(ann_file + '*')):
+ cam_name = img_folder.split('/')[-1]
+ img_name = i.split('/')[-1].replace('.npz','.png')
+ info = dict(license=3, height=512, width=512,
+ file_name=cam_name + '/' + img_name,
+ date_captured=i.split('/')[-1].split('.')[0],
+ id=count, filename=cam_name + '/' + img_name)
+ self.data_infs.append(info)
+
+ data["images"].append({'flickr_url': "unknown",
+ 'coco_url': "unknown",
+ 'file_name': cam_name+'/' +img_name,
+ 'id': count,
+ 'license':1,
+ 'date_captured': "unknown",
+ 'width': 512,
+ 'height': 512})
+ count = count+1
+
+ count = 0
+ obj_id = 0
+ for img_folder in sorted(glob.glob(ann_file.replace('GT_data','images') + '/*')):
+ for i in sorted(glob.glob(ann_file + '*')):
+ ann_in = self.get_ann_info_local(count)
+ for loop in range(len(ann_in['bboxes'])):
+ bbox = ann_in['bboxes'][loop]
+ segmentation = ann_in['masks'][loop]
+
+ data["annotations"].append({
+ 'image_id': count,
+ 'category_id': 0,
+ 'iscrowd': 0,
+ 'id': obj_id,
+ 'area': int(bbox[2]*bbox[3]),
+ 'bbox': [int(bbox[0]), int(bbox[1]), int(bbox[2])-int(bbox[0]),int(bbox[3])-int(bbox[1])],
+ 'segmentation': [segmentation]
+ })
+ obj_id = obj_id + 1
+ count = count+1
+ with open(osp.join(ann_file, 'ann.json'), 'w') as f:
+ json.dump(data, f)
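+
+ # The written ann.json is a minimal COCO-format file, roughly:
+ # {"info": {...}, "licenses": [...],
+ # "categories": [{"name": "car", "id": 0, "supercategory": "car"}],
+ # "images": [{"id": 0, "file_name": "<cam>/<frame>.png", "width": 512, ...}],
+ # "annotations": [{"id": 0, "image_id": 0, "bbox": [x, y, w, h],
+ # "segmentation": [...], ...}]}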
+
+ def get_ann_info_local(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+ return self._parse_ann_info_local(idx)
+
+ def _parse_ann_info_local(self, idx):
+ """Parse bbox and mask annotation.
+
+ Args:
+ ann_info (list[dict]): Annotation info of an image.
+ with_mask (bool): Whether to parse mask annotations.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,\
+ labels, masks, seg_map. "masks" are raw annotations and not \
+ decoded into binary masks.
+ """
+ try:
+ img_info = self.data_infs[idx]
+ except IndexError:
+ img_info = self.data_infs[0]
+
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+ import cv2
+
+ seg_o = cv2.imread(self.ann_file.replace('GT_data','Segmentation') + img_info['filename'])
+
+ seg_prev = seg_o
+ seg_next = seg_o
+ seg = seg_o
+
+
+ if len(np.unique(seg)) == 3:
+ try:
+ seg_prev = cv2.imread(self.ann_file.replace('GT_data','Segmentation') + self.data_infs[idx-1]['filename'])
+ except IndexError:
+ print('prev not found')
+ try:
+ seg_next = cv2.imread(self.ann_file.replace('GT_data','Segmentation') + self.data_infs[idx+1]['filename'])
+ except IndexError:
+ print('next not found')
+
+ try:
+ for i in np.unique(seg_o):
+ if i ==0:
+ continue
+ if i in np.unique(seg_prev):
+ seg = seg_prev
+ if i in np.unique(seg_next):
+ seg = seg_next
+ segmentations, encoded_ground_truth, ground_truth_binary_mask = self.get_segmentation(seg, i)
+ segmentations_original, encoded_ground_truth_original, ground_truth_binary_mask_original = self.get_segmentation(seg_o, i)
+
+ x1, y1, w, h = pycocotools.mask.toBbox(encoded_ground_truth)
+ x1_o, y1_o, w_o, h_o = pycocotools.mask.toBbox(encoded_ground_truth_original)
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ bbox = [x1, y1, x1 + w, y1 + h]
+ bbox_o = [x1_o, y1_o, x1_o + w_o, y1_o + h_o]
+ # If the mask borrowed from a neighbouring frame differs from this
+ # frame's mask, the object is (partially) occluded here: store the
+ # current mask as 'visible' and the borrowed one as 'full'.
+ if w != w_o or h != h_o or len(np.unique(ground_truth_binary_mask - ground_truth_binary_mask_original)) > 1:
+ gt_masks_ann.append({'visible': segmentations_original, 'full': segmentations})
+ gt_bboxes.append(bbox)
+ gt_labels.append(0)
+ else:
+ gt_masks_ann.append({'visible': segmentations, 'full': segmentations})
+ gt_bboxes.append(bbox)
+ gt_labels.append(0)
+
+ if inter_w * inter_h == 0:
+ continue
+
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+ seg_map = img_info['filename']
+ except Exception:
+ print('annotations failed to load for', img_info['filename'])
+ if len(gt_bboxes) == 0:
+ # Fall back to the next image when this one has no usable boxes.
+ ann = self._parse_ann_info_local(idx + 1)
+ return ann
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def get_segmentation(self, seg, idx):
+ """Binarise instance ``idx`` in the label image and return its polygon
+ contours, RLE encoding and binary mask."""
+ ground_truth_binary_mask = np.zeros_like(seg)
+ ground_truth_binary_mask[seg == idx] = 255
+ ground_truth_binary_mask = ground_truth_binary_mask[:, :, 0]
+ fortran_ground_truth_binary_mask = np.asfortranarray(ground_truth_binary_mask)
+ encoded_ground_truth = pycocotools.mask.encode(fortran_ground_truth_binary_mask)
+ from skimage import measure
+ contours = measure.find_contours(ground_truth_binary_mask, 0.5)
+ segmentations = []
+ for contour in contours:
+ contour = np.flip(contour, axis=1)
+ segmentation = contour.ravel().tolist()
+ segmentations.append(segmentation)
+ return segmentations, encoded_ground_truth, ground_truth_binary_mask
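+
+ # Sketch under assumed inputs (a 3-channel label image `seg` whose pixels
+ # carry instance ids): polys, rle, mask = self.get_segmentation(seg, 7)
+ # pycocotools.mask.area(rle) then gives the pixel area of instance 7.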
+
+
+
+ def get_ann_info(self, idx):
+ """Get COCO annotation by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ dict: Annotation info of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return self._parse_ann_info(self.data_infos[idx], ann_info)
+
+ def get_cat_ids(self, idx):
+ """Get COCO category ids by index.
+
+ Args:
+ idx (int): Index of data.
+
+ Returns:
+ list[int]: All categories in the image of specified index.
+ """
+
+ img_id = self.data_infos[idx]['id']
+ ann_ids = self.coco.get_ann_ids(img_ids=[img_id])
+ ann_info = self.coco.load_anns(ann_ids)
+ return [ann['category_id'] for ann in ann_info]
+
+ def _filter_imgs(self, min_size=32):
+ """Filter images too small or without ground truths."""
+ valid_inds = []
+ # obtain images that contain annotation
+ ids_with_ann = set(_['image_id'] for _ in self.coco.anns.values())
+ # obtain images that contain annotations of the required categories
+ ids_in_cat = set()
+ for i, class_id in enumerate(self.cat_ids):
+ ids_in_cat |= set(self.coco.cat_img_map[class_id])
+ # merge the image id sets of the two conditions and use the merged set
+ # to filter out images if self.filter_empty_gt=True
+ ids_in_cat &= ids_with_ann
+
+ valid_img_ids = []
+ for i, img_info in enumerate(self.data_infos):
+ img_id = self.img_ids[i]
+ if self.filter_empty_gt and img_id not in ids_in_cat:
+ continue
+ if min(img_info['width'], img_info['height']) >= min_size:
+ valid_inds.append(i)
+ valid_img_ids.append(img_id)
+ self.img_ids = valid_img_ids
+ return valid_inds
+
+ def _parse_ann_info(self, img_info, ann_info):
+ """Parse bbox and mask annotation.
+
+ Args:
+ ann_info (list[dict]): Annotation info of an image.
+ with_mask (bool): Whether to parse mask annotations.
+
+ Returns:
+ dict: A dict containing the following keys: bboxes, bboxes_ignore,\
+ labels, masks, seg_map. "masks" are raw annotations and not \
+ decoded into binary masks.
+ """
+ gt_bboxes = []
+ gt_labels = []
+ gt_bboxes_ignore = []
+ gt_masks_ann = []
+ for i, ann in enumerate(ann_info):
+ if ann.get('ignore', False):
+ continue
+ x1, y1, w, h = ann['bbox']
+ inter_w = max(0, min(x1 + w, img_info['width']) - max(x1, 0))
+ inter_h = max(0, min(y1 + h, img_info['height']) - max(y1, 0))
+ if inter_w * inter_h == 0:
+ continue
+ if ann['area'] <= 0 or w < 1 or h < 1:
+ continue
+ if ann['category_id'] not in self.cat_ids:
+ continue
+ bbox = [x1, y1, x1 + w, y1 + h]
+ #bbox = [x1, y1, w, h]
+ if ann.get('iscrowd', False):
+ gt_bboxes_ignore.append(bbox)
+ else:
+ gt_bboxes.append(bbox)
+ gt_labels.append(self.cat2label[ann['category_id']])
+ segm = ann['segmentation'][0]
+ if 'full' in segm:
+ gt_masks_ann.append({'visible': segm['visible'], 'full': segm['full']})
+ else:
+ gt_masks_ann.append({'visible': segm['visible']})
+
+ if gt_bboxes:
+ gt_bboxes = np.array(gt_bboxes, dtype=np.float32)
+ gt_labels = np.array(gt_labels, dtype=np.int64)
+ else:
+ gt_bboxes = np.zeros((0, 4), dtype=np.float32)
+ gt_labels = np.array([], dtype=np.int64)
+
+ if gt_bboxes_ignore:
+ gt_bboxes_ignore = np.array(gt_bboxes_ignore, dtype=np.float32)
+ else:
+ gt_bboxes_ignore = np.zeros((0, 4), dtype=np.float32)
+
+ seg_map = img_info['filename'].replace('jpg', 'png')
+
+ ann = dict(
+ bboxes=gt_bboxes,
+ labels=gt_labels,
+ bboxes_ignore=gt_bboxes_ignore,
+ masks=gt_masks_ann,
+ seg_map=seg_map)
+
+ return ann
+
+ def xyxy2xywh(self, bbox):
+ """Convert ``xyxy`` style bounding boxes to ``xywh`` style for COCO
+ evaluation.
+
+ Args:
+ bbox (numpy.ndarray): The bounding boxes, shape (4, ), in
+ ``xyxy`` order.
+
+ Returns:
+ list[float]: The converted bounding boxes, in ``xywh`` order.
+ """
+
+ _bbox = bbox.tolist()
+ return [
+ _bbox[0],
+ _bbox[1],
+ _bbox[2] - _bbox[0],
+ _bbox[3] - _bbox[1],
+ ]
+
+ def _proposal2json(self, results):
+ """Convert proposal results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ bboxes = results[idx]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = 1
+ json_results.append(data)
+ return json_results
+
+ def _det2json(self, results):
+ """Convert detection results to COCO json style."""
+ json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ result = results[idx]
+ for label in range(len(result)):
+ bboxes = result[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ json_results.append(data)
+ return json_results
+
+ def _segm2json(self, results):
+ """Convert instance segmentation results to COCO json style."""
+ bbox_json_results = []
+ segm_json_results = []
+ for idx in range(len(self)):
+ img_id = self.img_ids[idx]
+ det, seg = results[idx]
+ for label in range(len(det)):
+ # bbox results
+ bboxes = det[label]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(bboxes[i][4])
+ data['category_id'] = self.cat_ids[label]
+ bbox_json_results.append(data)
+
+ # segm results
+ # some detectors use different scores for bbox and mask
+ if isinstance(seg, tuple):
+ segms = seg[0][label]
+ mask_score = seg[1][label]
+ else:
+ segms = seg[label]
+ mask_score = [bbox[4] for bbox in bboxes]
+ for i in range(bboxes.shape[0]):
+ data = dict()
+ data['image_id'] = img_id
+ data['bbox'] = self.xyxy2xywh(bboxes[i])
+ data['score'] = float(mask_score[i])
+ data['category_id'] = self.cat_ids[label]
+ if isinstance(segms[i]['counts'], bytes):
+ segms[i]['counts'] = segms[i]['counts'].decode()
+ data['segmentation'] = segms[i]
+ segm_json_results.append(data)
+ return bbox_json_results, segm_json_results
+
+ def results2json(self, results, outfile_prefix):
+ """Dump the detection results to a COCO style json file.
+
+ There are 3 types of results: proposals, bbox predictions, mask
+ predictions, and they have different data types. This method will
+ automatically recognize the type, and dump them to json files.
+
+ Args:
+ results (list[list | tuple | ndarray]): Testing results of the
+ dataset.
+ outfile_prefix (str): The filename prefix of the json files. If the
+ prefix is "somepath/xxx", the json files will be named
+ "somepath/xxx.bbox.json", "somepath/xxx.segm.json",
+ "somepath/xxx.proposal.json".
+
+ Returns:
+ dict[str: str]: Possible keys are "bbox", "segm", "proposal", and \
+ values are corresponding filenames.
+ """
+ result_files = dict()
+ if isinstance(results[0], list):
+ json_results = self._det2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ mmcv.dump(json_results, result_files['bbox'])
+ elif isinstance(results[0], tuple):
+ json_results = self._segm2json(results)
+ result_files['bbox'] = f'{outfile_prefix}.bbox.json'
+ result_files['proposal'] = f'{outfile_prefix}.bbox.json'
+ result_files['segm'] = f'{outfile_prefix}.segm.json'
+ mmcv.dump(json_results[0], result_files['bbox'])
+ mmcv.dump(json_results[1], result_files['segm'])
+ elif isinstance(results[0], np.ndarray):
+ json_results = self._proposal2json(results)
+ result_files['proposal'] = f'{outfile_prefix}.proposal.json'
+ mmcv.dump(json_results, result_files['proposal'])
+ else:
+ raise TypeError('invalid type of results')
+ return result_files
+
+ def fast_eval_recall(self, results, proposal_nums, iou_thrs, logger=None):
+ gt_bboxes = []
+ for i in range(len(self.img_ids)):
+ ann_ids = self.coco.get_ann_ids(img_ids=self.img_ids[i])
+ ann_info = self.coco.load_anns(ann_ids)
+ if len(ann_info) == 0:
+ gt_bboxes.append(np.zeros((0, 4)))
+ continue
+ bboxes = []
+ for ann in ann_info:
+ if ann.get('ignore', False) or ann['iscrowd']:
+ continue
+ x1, y1, w, h = ann['bbox']
+ bboxes.append([x1, y1, x1 + w, y1 + h])
+ bboxes = np.array(bboxes, dtype=np.float32)
+ if bboxes.shape[0] == 0:
+ bboxes = np.zeros((0, 4))
+ gt_bboxes.append(bboxes)
+
+ recalls = eval_recalls(
+ gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
+ ar = recalls.mean(axis=1)
+ return ar
+
+ def format_results(self, results, jsonfile_prefix=None, **kwargs):
+ """Format the results to json (standard format for COCO evaluation).
+
+ Args:
+ results (list[tuple | numpy.ndarray]): Testing results of the
+ dataset.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+
+ Returns:
+ tuple: (result_files, tmp_dir), result_files is a dict containing \
+ the json filepaths, tmp_dir is the temporary directory created \
+ for saving json files when jsonfile_prefix is not specified.
+ """
+ assert isinstance(results, list), 'results must be a list'
+ assert len(results) == len(self), (
+ 'The length of results is not equal to the dataset len: {} != {}'.
+ format(len(results), len(self)))
+
+ if jsonfile_prefix is None:
+ tmp_dir = tempfile.TemporaryDirectory()
+ jsonfile_prefix = osp.join(tmp_dir.name, 'results')
+ else:
+ tmp_dir = None
+ result_files = self.results2json(results, jsonfile_prefix)
+ return result_files, tmp_dir
+
+ def evaluate(self,
+ results,
+ metric='bbox',
+ logger=None,
+ jsonfile_prefix=None,
+ classwise=False,
+ proposal_nums=(100, 300, 1000),
+ iou_thrs=None,
+ metric_items=None):
+ """Evaluation in COCO protocol.
+
+ Args:
+ results (list[list | tuple]): Testing results of the dataset.
+ metric (str | list[str]): Metrics to be evaluated. Options are
+ 'bbox', 'segm', 'proposal', 'proposal_fast'.
+ logger (logging.Logger | str | None): Logger used for printing
+ related information during evaluation. Default: None.
+ jsonfile_prefix (str | None): The prefix of json files. It includes
+ the file path and the prefix of filename, e.g., "a/b/prefix".
+ If not specified, a temp file will be created. Default: None.
+ classwise (bool): Whether to evaluate the AP for each class.
+ proposal_nums (Sequence[int]): Proposal number used for evaluating
+ recalls, such as recall@100, recall@1000.
+ Default: (100, 300, 1000).
+ iou_thrs (Sequence[float], optional): IoU threshold used for
+ evaluating recalls/mAPs. If set to a list, the average of all
+ IoUs will also be computed. If not specified, [0.50, 0.55,
+ 0.60, 0.65, 0.70, 0.75, 0.80, 0.85, 0.90, 0.95] will be used.
+ Default: None.
+ metric_items (list[str] | str, optional): Metric items that will
+ be returned. If not specified, ``['AR@100', 'AR@300',
+ 'AR@1000', 'AR_s@1000', 'AR_m@1000', 'AR_l@1000' ]`` will be
+ used when ``metric=='proposal'``, ``['mAP', 'mAP_50', 'mAP_75',
+ 'mAP_s', 'mAP_m', 'mAP_l']`` will be used when
+ ``metric=='bbox' or metric=='segm'``.
+
+ Returns:
+ dict[str, float]: COCO style evaluation metric.
+ """
+
+ metrics = metric if isinstance(metric, list) else [metric]
+ allowed_metrics = ['bbox', 'segm', 'proposal', 'proposal_fast']
+ for metric in metrics:
+ if metric not in allowed_metrics:
+ raise KeyError(f'metric {metric} is not supported')
+ if iou_thrs is None:
+ iou_thrs = np.linspace(
+ .5, 0.95, int(np.round((0.95 - .5) / .05)) + 1, endpoint=True)
+ if metric_items is not None:
+ if not isinstance(metric_items, list):
+ metric_items = [metric_items]
+
+ result_files, tmp_dir = self.format_results(results, jsonfile_prefix)
+
+ eval_results = OrderedDict()
+ cocoGt = self.coco
+ # The amodal annotations store [{'visible': ..., 'full': ...}]; COCO
+ # evaluation here scores against the visible-mask polygons.
+ for ann_id in cocoGt.anns:
+ cocoGt.anns[ann_id]['segmentation'] = cocoGt.anns[ann_id]['segmentation'][0]['visible']
+
+ for metric in metrics:
+ msg = f'Evaluating {metric}...'
+ if logger is None:
+ msg = '\n' + msg
+ print_log(msg, logger=logger)
+
+ if metric == 'proposal_fast':
+ ar = self.fast_eval_recall(
+ results, proposal_nums, iou_thrs, logger='silent')
+ log_msg = []
+ for i, num in enumerate(proposal_nums):
+ eval_results[f'AR@{num}'] = ar[i]
+ log_msg.append(f'\nAR@{num}\t{ar[i]:.4f}')
+ log_msg = ''.join(log_msg)
+ print_log(log_msg, logger=logger)
+ continue
+
+ if metric not in result_files:
+ raise KeyError(f'{metric} is not in results')
+ try:
+ cocoDt = cocoGt.loadRes(result_files[metric])
+ except IndexError:
+ print_log(
+ 'The testing results of the whole dataset is empty.',
+ logger=logger,
+ level=logging.ERROR)
+ break
+
+ iou_type = 'bbox' if metric == 'proposal' else metric
+ cocoEval = COCOeval(cocoGt, cocoDt, iou_type)
+ cocoEval.params.catIds = self.cat_ids
+ cocoEval.params.imgIds = self.img_ids
+ cocoEval.params.maxDets = list(proposal_nums)
+ cocoEval.params.iouThrs = iou_thrs
+ # mapping of cocoEval.stats
+ coco_metric_names = {
+ 'mAP': 0,
+ 'mAP_50': 1,
+ 'mAP_75': 2,
+ 'mAP_s': 3,
+ 'mAP_m': 4,
+ 'mAP_l': 5,
+ 'AR@100': 6,
+ 'AR@300': 7,
+ 'AR@1000': 8,
+ 'AR_s@1000': 9,
+ 'AR_m@1000': 10,
+ 'AR_l@1000': 11
+ }
+ if metric_items is not None:
+ for metric_item in metric_items:
+ if metric_item not in coco_metric_names:
+ raise KeyError(
+ f'metric item {metric_item} is not supported')
+
+ if metric == 'proposal':
+ cocoEval.params.useCats = 0
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if metric_items is None:
+ metric_items = [
+ 'AR@100', 'AR@300', 'AR@1000', 'AR_s@1000',
+ 'AR_m@1000', 'AR_l@1000'
+ ]
+
+ for item in metric_items:
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[item]]:.3f}')
+ eval_results[item] = val
+ else:
+ cocoEval.evaluate()
+ cocoEval.accumulate()
+ cocoEval.summarize()
+ if classwise: # Compute per-category AP
+ # Compute per-category AP
+ # from https://github.com/facebookresearch/detectron2/
+ precisions = cocoEval.eval['precision']
+ # precision: (iou, recall, cls, area range, max dets)
+ assert len(self.cat_ids) == precisions.shape[2]
+
+ results_per_category = []
+ for idx, catId in enumerate(self.cat_ids):
+ # area range index 0: all area ranges
+ # max dets index -1: typically 100 per image
+ nm = self.coco.loadCats(catId)[0]
+ precision = precisions[:, :, idx, 0, -1]
+ precision = precision[precision > -1]
+ if precision.size:
+ ap = np.mean(precision)
+ else:
+ ap = float('nan')
+ results_per_category.append(
+ (f'{nm["name"]}', f'{float(ap):0.3f}'))
+
+ num_columns = min(6, len(results_per_category) * 2)
+ results_flatten = list(
+ itertools.chain(*results_per_category))
+ headers = ['category', 'AP'] * (num_columns // 2)
+ results_2d = itertools.zip_longest(*[
+ results_flatten[i::num_columns]
+ for i in range(num_columns)
+ ])
+ table_data = [headers]
+ table_data += [result for result in results_2d]
+ table = AsciiTable(table_data)
+ print_log('\n' + table.table, logger=logger)
+
+ if metric_items is None:
+ metric_items = [
+ 'mAP', 'mAP_50', 'mAP_75', 'mAP_s', 'mAP_m', 'mAP_l'
+ ]
+
+ for metric_item in metric_items:
+ key = f'{metric}_{metric_item}'
+ val = float(
+ f'{cocoEval.stats[coco_metric_names[metric_item]]:.3f}'
+ )
+ eval_results[key] = val
+ ap = cocoEval.stats[:6]
+ eval_results[f'{metric}_mAP_copypaste'] = (
+ f'{ap[0]:.3f} {ap[1]:.3f} {ap[2]:.3f} {ap[3]:.3f} '
+ f'{ap[4]:.3f} {ap[5]:.3f}')
+ if tmp_dir is not None:
+ tmp_dir.cleanup()
+ return eval_results
diff --git a/walt/train.py b/walt/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce383e27d5d067d490bc57f3b861e4fb0fefda50
--- /dev/null
+++ b/walt/train.py
@@ -0,0 +1,188 @@
+import argparse
+import copy
+import os
+import os.path as osp
+import time
+import warnings
+
+import mmcv
+import torch
+from mmcv import Config, DictAction
+from mmcv.runner import get_dist_info, init_dist
+from mmcv.utils import get_git_hash
+
+from mmdet import __version__
+from mmdet.apis import set_random_seed
+from code_local.apis import train_detector
+from code_local.datasets import build_dataset
+from mmdet.models import build_detector
+from mmdet.utils import collect_env, get_root_logger
+
+
+def parse_args():
+ parser = argparse.ArgumentParser(description='Train a detector')
+ parser.add_argument('config', help='train config file path')
+ parser.add_argument('--work-dir', help='the dir to save logs and models')
+ parser.add_argument(
+ '--resume-from', help='the checkpoint file to resume from')
+ parser.add_argument(
+ '--no-validate',
+ action='store_true',
+ help='whether not to evaluate the checkpoint during training')
+ group_gpus = parser.add_mutually_exclusive_group()
+ group_gpus.add_argument(
+ '--gpus',
+ type=int,
+ help='number of gpus to use '
+ '(only applicable to non-distributed training)')
+ group_gpus.add_argument(
+ '--gpu-ids',
+ type=int,
+ nargs='+',
+ help='ids of gpus to use '
+ '(only applicable to non-distributed training)')
+ parser.add_argument('--seed', type=int, default=None, help='random seed')
+ parser.add_argument(
+ '--deterministic',
+ action='store_true',
+ help='whether to set deterministic options for CUDNN backend.')
+ parser.add_argument(
+ '--options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file (deprecated), '
+ 'change to --cfg-options instead.')
+ parser.add_argument(
+ '--cfg-options',
+ nargs='+',
+ action=DictAction,
+ help='override some settings in the used config, the key-value pair '
+ 'in xxx=yyy format will be merged into config file. If the value to '
+ 'be overwritten is a list, it should be like key="[a,b]" or key=a,b '
+ 'It also allows nested list/tuple values, e.g. key="[(a,b),(c,d)]" '
+ 'Note that the quotation marks are necessary and that no white space '
+ 'is allowed.')
+ parser.add_argument(
+ '--launcher',
+ choices=['none', 'pytorch', 'slurm', 'mpi'],
+ default='none',
+ help='job launcher')
+ parser.add_argument('--local_rank', type=int, default=0)
+ args = parser.parse_args()
+ if 'LOCAL_RANK' not in os.environ:
+ os.environ['LOCAL_RANK'] = str(args.local_rank)
+
+ if args.options and args.cfg_options:
+ raise ValueError(
+ '--options and --cfg-options cannot be both '
+ 'specified, --options is deprecated in favor of --cfg-options')
+ if args.options:
+ warnings.warn('--options is deprecated in favor of --cfg-options')
+ args.cfg_options = args.options
+
+ return args
+
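+ # Example invocation (paths are placeholders for your setup):
+ # python walt/train.py configs/walt/walt_vehicle.py \
+ #     --work-dir work_dirs/walt_vehicle --gpus 1 --seed 0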
+
+def main():
+ args = parse_args()
+
+ cfg = Config.fromfile(args.config)
+ if args.cfg_options is not None:
+ cfg.merge_from_dict(args.cfg_options)
+ # import modules from string list.
+ if cfg.get('custom_imports', None):
+ from mmcv.utils import import_modules_from_strings
+ import_modules_from_strings(**cfg['custom_imports'])
+ # set cudnn_benchmark
+ if cfg.get('cudnn_benchmark', False):
+ torch.backends.cudnn.benchmark = True
+
+ # work_dir is determined in this priority: CLI > segment in file > filename
+ if args.work_dir is not None:
+ # update configs according to CLI args if args.work_dir is not None
+ cfg.work_dir = args.work_dir
+ elif cfg.get('work_dir', None) is None:
+ # use config filename as default work_dir if cfg.work_dir is None
+ cfg.work_dir = osp.join('./work_dirs',
+ osp.splitext(osp.basename(args.config))[0])
+ if args.resume_from is not None:
+ cfg.resume_from = args.resume_from
+ if args.gpu_ids is not None:
+ cfg.gpu_ids = args.gpu_ids
+ else:
+ cfg.gpu_ids = range(1) if args.gpus is None else range(args.gpus)
+
+ # init distributed env first, since logger depends on the dist info.
+ if args.launcher == 'none':
+ distributed = False
+ else:
+ distributed = True
+ init_dist(args.launcher, **cfg.dist_params)
+ # re-set gpu_ids with distributed training mode
+ _, world_size = get_dist_info()
+ cfg.gpu_ids = range(world_size)
+
+ # create work_dir
+ mmcv.mkdir_or_exist(osp.abspath(cfg.work_dir))
+ # dump config
+ cfg.dump(osp.join(cfg.work_dir, osp.basename(args.config)))
+ # init the logger before other steps
+ timestamp = time.strftime('%Y%m%d_%H%M%S', time.localtime())
+ log_file = osp.join(cfg.work_dir, f'{timestamp}.log')
+ logger = get_root_logger(log_file=log_file, log_level=cfg.log_level)
+
+ # init the meta dict to record some important information such as
+ # environment info and seed, which will be logged
+ meta = dict()
+ # log env info
+ env_info_dict = collect_env()
+ env_info = '\n'.join([(f'{k}: {v}') for k, v in env_info_dict.items()])
+ dash_line = '-' * 60 + '\n'
+ logger.info('Environment info:\n' + dash_line + env_info + '\n' +
+ dash_line)
+ meta['env_info'] = env_info
+ meta['config'] = cfg.pretty_text
+ # log some basic info
+ logger.info(f'Distributed training: {distributed}')
+ logger.info(f'Config:\n{cfg.pretty_text}')
+
+ # set random seeds
+ if args.seed is not None:
+ logger.info(f'Set random seed to {args.seed}, '
+ f'deterministic: {args.deterministic}')
+ set_random_seed(args.seed, deterministic=args.deterministic)
+ cfg.seed = args.seed
+ meta['seed'] = args.seed
+ meta['exp_name'] = osp.basename(args.config)
+
+ model = build_detector(
+ cfg.model,
+ train_cfg=cfg.get('train_cfg'),
+ test_cfg=cfg.get('test_cfg'))
+
+ datasets = [build_dataset(cfg.data.train)]
+ if len(cfg.workflow) == 2:
+ val_dataset = copy.deepcopy(cfg.data.val)
+ val_dataset.pipeline = cfg.data.train.pipeline
+ datasets.append(build_dataset(val_dataset))
+ if cfg.checkpoint_config is not None:
+ # save mmdet version, config file content and class names in
+ # checkpoints as meta data
+ cfg.checkpoint_config.meta = dict(
+ mmdet_version=__version__ + get_git_hash()[:7],
+ CLASSES=datasets[0].CLASSES)
+ # add an attribute for visualization convenience
+ model.CLASSES = datasets[0].CLASSES
+ train_detector(
+ model,
+ datasets,
+ cfg,
+ distributed=distributed,
+ validate=(not args.no_validate),
+ timestamp=timestamp,
+ meta=meta)
+
+
+if __name__ == '__main__':
+ main()