Your Name committed on
Commit a56642d
1 parent: 86c421b

update demo

This view is limited to 50 files because the commit contains too many changes. See the raw diff for the full change set.
Files changed (50)
  1. LICENSE +20 -0
  2. README.md +6 -5
  3. app.py +79 -0
  4. configs/_base_/datasets/parking_instance.py +48 -0
  5. configs/_base_/datasets/parking_instance_coco.py +49 -0
  6. configs/_base_/datasets/people_real_coco.py +49 -0
  7. configs/_base_/datasets/walt_people.py +49 -0
  8. configs/_base_/datasets/walt_vehicle.py +49 -0
  9. configs/_base_/default_runtime.py +16 -0
  10. configs/_base_/models/mask_rcnn_swin_fpn.py +127 -0
  11. configs/_base_/models/occ_mask_rcnn_swin_fpn.py +127 -0
  12. configs/_base_/schedules/schedule_1x.py +11 -0
  13. configs/walt/walt_people.py +80 -0
  14. configs/walt/walt_vehicle.py +80 -0
  15. cwalt/CWALT.py +161 -0
  16. cwalt/Clip_WALT_Generate.py +284 -0
  17. cwalt/Download_Detections.py +28 -0
  18. cwalt/clustering_utils.py +132 -0
  19. cwalt/kmedoid.py +55 -0
  20. cwalt/utils.py +168 -0
  21. cwalt_generate.py +14 -0
  22. docker/Dockerfile +52 -0
  23. github_vis/cwalt.gif +0 -0
  24. github_vis/vis_cars.gif +0 -0
  25. github_vis/vis_people.gif +0 -0
  26. infer.py +118 -0
  27. mmcv_custom/__init__.py +5 -0
  28. mmcv_custom/checkpoint.py +500 -0
  29. mmcv_custom/runner/__init__.py +8 -0
  30. mmcv_custom/runner/checkpoint.py +85 -0
  31. mmcv_custom/runner/epoch_based_runner.py +104 -0
  32. mmdet/__init__.py +28 -0
  33. mmdet/apis/__init__.py +10 -0
  34. mmdet/apis/inference.py +217 -0
  35. mmdet/apis/test.py +189 -0
  36. mmdet/apis/train.py +185 -0
  37. mmdet/core/__init__.py +7 -0
  38. mmdet/core/anchor/__init__.py +11 -0
  39. mmdet/core/anchor/anchor_generator.py +727 -0
  40. mmdet/core/anchor/builder.py +7 -0
  41. mmdet/core/anchor/point_generator.py +37 -0
  42. mmdet/core/anchor/utils.py +71 -0
  43. mmdet/core/bbox/__init__.py +27 -0
  44. mmdet/core/bbox/assigners/__init__.py +16 -0
  45. mmdet/core/bbox/assigners/approx_max_iou_assigner.py +145 -0
  46. mmdet/core/bbox/assigners/assign_result.py +204 -0
  47. mmdet/core/bbox/assigners/atss_assigner.py +178 -0
  48. mmdet/core/bbox/assigners/base_assigner.py +9 -0
  49. mmdet/core/bbox/assigners/center_region_assigner.py +335 -0
  50. mmdet/core/bbox/assigners/grid_assigner.py +155 -0
LICENSE ADDED
@@ -0,0 +1,20 @@
+ Copyright (c) 2022-2022 dinesh reddy and others
+
+ Permission is hereby granted, free of charge, to any person obtaining
+ a copy of this software and associated documentation files (the
+ "Software"), to deal in the Software without restriction, including
+ without limitation the rights to use, copy, modify, merge, publish,
+ distribute, sublicense, and/or sell copies of the Software, and to
+ permit persons to whom the Software is furnished to do so, subject to
+ the following conditions:
+
+ The above copyright notice and this permission notice shall be
+ included in all copies or substantial portions of the Software.
+
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+ LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+ WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
README.md CHANGED
@@ -1,12 +1,13 @@
  ---
- title: WALT
- emoji: 👁
- colorFrom: purple
- colorTo: yellow
+ title: WALT DEMO
+ emoji:
+ colorFrom: indigo
+ colorTo: indigo
  sdk: gradio
- sdk_version: 3.0.21
+ sdk_version: 3.0.20
  app_file: app.py
  pinned: false
+ license: mit
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,79 @@
+ import numpy as np
+ import torch
+ import gradio as gr
+ from infer import detections
+ '''
+ import os
+ os.system("mkdir data")
+ os.system("mkdir data/models")
+ os.system("wget https://www.cs.cmu.edu/~walt/models/walt_people.pth -O data/models/walt_people.pth")
+ os.system("wget https://www.cs.cmu.edu/~walt/models/walt_vehicle.pth -O data/models/walt_vehicle.pth")
+ '''
+ def walt_demo(input_img, confidence_threshold):
+     #detect_people = detections('configs/walt/walt_people.py', 'cuda:0', model_path='data/models/walt_people.pth')
+     if torch.cuda.is_available() == False:
+         device = 'cpu'
+     else:
+         device = 'cuda:0'
+     #detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
+     detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth', threshold=confidence_threshold)
+
+     count = 0
+     #img = detect_people.run_on_image(input_img)
+     output_img = detect.run_on_image(input_img)
+     #try:
+     #except:
+     #    print("detecting on image failed")
+
+     return output_img
+
+ description = """
+ WALT demo on the WALT dataset. After watching and automatically learning for several days, this approach shows a significant performance improvement in detecting and segmenting occluded people and vehicles over human-supervised amodal approaches.
+ <center>
+ <a href="https://www.cs.cmu.edu/~walt/">
+ <img style="display:inline" alt="Project page" src="https://img.shields.io/badge/Project%20Page-WALT-green">
+ </a>
+ <a href="https://www.cs.cmu.edu/~walt/pdf/walt.pdf"><img style="display:inline" src="https://img.shields.io/badge/Paper-Pdf-red"></a>
+ <a href="https://github.com/dineshreddy91/WALT"><img style="display:inline" src="https://img.shields.io/github/stars/dineshreddy91/WALT?style=social"></a>
+ </center>
+ """
+ title = "WALT: Watch And Learn 2D Amodal Representation using Time-lapse Imagery"
+ article = """
+ <center>
+ <img src='https://visitor-badge.glitch.me/badge?page_id=anhquancao.MonoScene&left_color=darkmagenta&right_color=purple' alt='visitor badge'>
+ </center>
+ """
+
+ examples = [
+     ['demo/images/img_1.jpg', 0.8],
+     ['demo/images/img_2.jpg', 0.8],
+     ['demo/images/img_4.png', 0.85],
+ ]
+
+ '''
+ import cv2
+ filename = 'demo/images/img_1.jpg'
+ img = cv2.imread(filename)
+ img = walt_demo(img)
+ cv2.imwrite(filename.replace('/images/', '/results/'), img)
+ cv2.imwrite('check.png', img)
+ '''
+ confidence_threshold = gr.Slider(minimum=0.3,
+                                  maximum=1.0,
+                                  step=0.01,
+                                  value=1.0,
+                                  label="Amodal Detection Confidence Threshold")
+ inputs = [gr.Image(), confidence_threshold]
+ demo = gr.Interface(walt_demo,
+                     outputs="image",
+                     inputs=inputs,
+                     article=article,
+                     title=title,
+                     enable_queue=True,
+                     examples=examples,
+                     description=description)
+
+ #demo.launch(server_name="0.0.0.0", server_port=7000)
+ demo.launch()
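Aside (not part of the commit): a minimal local smoke test of the `detections` wrapper used above, assuming the checkpoint has been fetched to data/models/walt_vehicle.pth as in the commented-out download block:

    import cv2
    from infer import detections

    # Build the vehicle detector on CPU; threshold mirrors the demo slider.
    detector = detections('configs/walt/walt_vehicle.py', 'cpu',
                          model_path='data/models/walt_vehicle.pth',
                          threshold=0.8)
    img = cv2.imread('demo/images/img_1.jpg')  # BGR numpy array
    out = detector.run_on_image(img)           # image with detections drawn, per app.py usage
    cv2.imwrite('check.png', out)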
configs/_base_/datasets/parking_instance.py ADDED
@@ -0,0 +1,48 @@
+ dataset_type = 'ParkingDataset'
+ data_root = 'data/parking/'
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_bboxes_3d', 'gt_bboxes_3d_proj']),
+ ]
+ test_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(
+         type='MultiScaleFlipAug',
+         img_scale=(1333, 800),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img']),
+         ])
+ ]
+ data = dict(
+     samples_per_gpu=1,
+     workers_per_gpu=1,
+     train=dict(
+         type=dataset_type,
+         ann_file=data_root + 'GT_data/',
+         img_prefix=data_root + 'images/',
+         pipeline=train_pipeline),
+     val=dict(
+         type=dataset_type,
+         ann_file=data_root + 'GT_data/',
+         img_prefix=data_root + 'images/',
+         pipeline=test_pipeline),
+     test=dict(
+         type=dataset_type,
+         ann_file=data_root + 'GT_data/',
+         img_prefix=data_root + 'images/',
+         pipeline=test_pipeline))
+ evaluation = dict(metric=['bbox'])  # , 'segm'])
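Aside: these _base_ dataset files are standard MMDetection config fragments (train/test pipelines plus a `data` dict) that the top-level configs inherit. A sketch of how such a fragment is consumed, assuming an MMDetection 2.x environment in which the custom dataset classes (ParkingDataset, WaltDataset) are registered:

    from mmcv import Config
    from mmdet.datasets import build_dataset

    cfg = Config.fromfile('configs/_base_/datasets/parking_instance.py')
    train_set = build_dataset(cfg.data.train)  # ParkingDataset with train_pipeline applied
    print(len(train_set))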
configs/_base_/datasets/parking_instance_coco.py ADDED
@@ -0,0 +1,49 @@
+ dataset_type = 'ParkingCocoDataset'
+ data_root = 'data/parking/'
+ data_root_test = 'data/parking_highres/'
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ test_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(
+         type='MultiScaleFlipAug',
+         img_scale=(1333, 800),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img']),
+         ])
+ ]
+ data = dict(
+     samples_per_gpu=6,
+     workers_per_gpu=6,
+     train=dict(
+         type=dataset_type,
+         ann_file=data_root + 'GT_data/',
+         img_prefix=data_root + 'images/',
+         pipeline=train_pipeline),
+     val=dict(
+         type=dataset_type,
+         ann_file=data_root_test + 'GT_data/',
+         img_prefix=data_root_test + 'images',
+         pipeline=test_pipeline),
+     test=dict(
+         type=dataset_type,
+         ann_file=data_root_test + 'GT_data/',
+         img_prefix=data_root_test + 'images',
+         pipeline=test_pipeline))
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/people_real_coco.py ADDED
@@ -0,0 +1,49 @@
+ dataset_type = 'WaltDataset'
+ data_root = 'data/cwalt_train/'
+ data_root_test = 'data/cwalt_test/'
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ test_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(
+         type='MultiScaleFlipAug',
+         img_scale=(1333, 800),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img']),
+         ])
+ ]
+ data = dict(
+     samples_per_gpu=8,
+     workers_per_gpu=8,
+     train=dict(
+         type=dataset_type,
+         ann_file=data_root + '/',
+         img_prefix=data_root + '/',
+         pipeline=train_pipeline),
+     val=dict(
+         type=dataset_type,
+         ann_file=data_root_test + '/',
+         img_prefix=data_root_test + '/',
+         pipeline=test_pipeline),
+     test=dict(
+         type=dataset_type,
+         ann_file=data_root_test + '/',
+         img_prefix=data_root_test + '/',
+         pipeline=test_pipeline))
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/walt_people.py ADDED
@@ -0,0 +1,49 @@
+ dataset_type = 'WaltDataset'
+ data_root = 'data/cwalt_train/'
+ data_root_test = 'data/cwalt_test/'
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ test_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(
+         type='MultiScaleFlipAug',
+         img_scale=(1333, 800),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img']),
+         ])
+ ]
+ data = dict(
+     samples_per_gpu=8,
+     workers_per_gpu=8,
+     train=dict(
+         type=dataset_type,
+         ann_file=data_root + '/',
+         img_prefix=data_root + '/',
+         pipeline=train_pipeline),
+     val=dict(
+         type=dataset_type,
+         ann_file=data_root_test + '/',
+         img_prefix=data_root_test + '/',
+         pipeline=test_pipeline),
+     test=dict(
+         type=dataset_type,
+         ann_file=data_root_test + '/',
+         img_prefix=data_root_test + '/',
+         pipeline=test_pipeline))
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/datasets/walt_vehicle.py ADDED
@@ -0,0 +1,49 @@
+ dataset_type = 'WaltDataset'
+ data_root = 'data/cwalt_train/'
+ data_root_test = 'data/cwalt_test/'
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='Resize', img_scale=(1333, 800), keep_ratio=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ test_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(
+         type='MultiScaleFlipAug',
+         img_scale=(1333, 800),
+         flip=False,
+         transforms=[
+             dict(type='Resize', keep_ratio=True),
+             dict(type='RandomFlip'),
+             dict(type='Normalize', **img_norm_cfg),
+             dict(type='Pad', size_divisor=32),
+             dict(type='ImageToTensor', keys=['img']),
+             dict(type='Collect', keys=['img']),
+         ])
+ ]
+ data = dict(
+     samples_per_gpu=5,
+     workers_per_gpu=5,
+     train=dict(
+         type=dataset_type,
+         ann_file=data_root + '/',
+         img_prefix=data_root + '/',
+         pipeline=train_pipeline),
+     val=dict(
+         type=dataset_type,
+         ann_file=data_root_test + '/',
+         img_prefix=data_root_test + '/',
+         pipeline=test_pipeline),
+     test=dict(
+         type=dataset_type,
+         ann_file=data_root_test + '/',
+         img_prefix=data_root_test + '/',
+         pipeline=test_pipeline))
+ evaluation = dict(metric=['bbox', 'segm'])
configs/_base_/default_runtime.py ADDED
@@ -0,0 +1,16 @@
+ checkpoint_config = dict(interval=1)
+ # yapf:disable
+ log_config = dict(
+     interval=50,
+     hooks=[
+         dict(type='TextLoggerHook'),
+         # dict(type='TensorboardLoggerHook')
+     ])
+ # yapf:enable
+ custom_hooks = [dict(type='NumClassCheckHook')]
+
+ dist_params = dict(backend='nccl')
+ log_level = 'INFO'
+ load_from = None
+ resume_from = None
+ workflow = [('train', 1)]
configs/_base_/models/mask_rcnn_swin_fpn.py ADDED
@@ -0,0 +1,127 @@
+ # model settings
+ model = dict(
+     type='MaskRCNN',
+     pretrained=None,
+     backbone=dict(
+         type='SwinTransformer',
+         embed_dim=96,
+         depths=[2, 2, 6, 2],
+         num_heads=[3, 6, 12, 24],
+         window_size=7,
+         mlp_ratio=4.,
+         qkv_bias=True,
+         qk_scale=None,
+         drop_rate=0.,
+         attn_drop_rate=0.,
+         drop_path_rate=0.2,
+         ape=False,
+         patch_norm=True,
+         out_indices=(0, 1, 2, 3),
+         use_checkpoint=False),
+     neck=dict(
+         type='FPN',
+         in_channels=[96, 192, 384, 768],
+         out_channels=256,
+         num_outs=5),
+     rpn_head=dict(
+         type='RPNHead',
+         in_channels=256,
+         feat_channels=256,
+         anchor_generator=dict(
+             type='AnchorGenerator',
+             scales=[8],
+             ratios=[0.5, 1.0, 2.0],
+             strides=[4, 8, 16, 32, 64]),
+         bbox_coder=dict(
+             type='DeltaXYWHBBoxCoder',
+             target_means=[.0, .0, .0, .0],
+             target_stds=[1.0, 1.0, 1.0, 1.0]),
+         loss_cls=dict(
+             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+         loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+     roi_head=dict(
+         type='StandardRoIHead',
+         bbox_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         bbox_head=dict(
+             type='Shared2FCBBoxHead',
+             in_channels=256,
+             fc_out_channels=1024,
+             roi_feat_size=7,
+             num_classes=80,
+             bbox_coder=dict(
+                 type='DeltaXYWHBBoxCoder',
+                 target_means=[0., 0., 0., 0.],
+                 target_stds=[0.1, 0.1, 0.2, 0.2]),
+             reg_class_agnostic=False,
+             loss_cls=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+             loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+         mask_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         mask_head=dict(
+             type='FCNMaskHead',
+             num_convs=4,
+             in_channels=256,
+             conv_out_channels=256,
+             num_classes=80,
+             loss_mask=dict(
+                 type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+     # model training and testing settings
+     train_cfg=dict(
+         rpn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.7,
+                 neg_iou_thr=0.3,
+                 min_pos_iou=0.3,
+                 match_low_quality=True,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=256,
+                 pos_fraction=0.5,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=False),
+             allowed_border=-1,
+             pos_weight=-1,
+             debug=False),
+         rpn_proposal=dict(
+             nms_pre=2000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.5,
+                 neg_iou_thr=0.5,
+                 min_pos_iou=0.5,
+                 match_low_quality=True,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=512,
+                 pos_fraction=0.25,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=True),
+             mask_size=28,
+             pos_weight=-1,
+             debug=False)),
+     test_cfg=dict(
+         rpn=dict(
+             nms_pre=1000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             score_thr=0.05,
+             nms=dict(type='nms', iou_threshold=0.5),
+             max_per_img=100,
+             mask_thr_binary=0.5)))
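Aside: this is a standard Swin-T Mask R-CNN fragment; the occ_ variant that follows differs only in the mask head type (FCNOccMaskHead instead of FCNMaskHead). A sketch of instantiating it, assuming MMDetection 2.x with the SwinTransformer backbone registered:

    from mmcv import Config
    from mmdet.models import build_detector

    cfg = Config.fromfile('configs/_base_/models/mask_rcnn_swin_fpn.py')
    model = build_detector(cfg.model)  # train_cfg/test_cfg are embedded in cfg.model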
configs/_base_/models/occ_mask_rcnn_swin_fpn.py ADDED
@@ -0,0 +1,127 @@
+ # model settings
+ model = dict(
+     type='MaskRCNN',
+     pretrained=None,
+     backbone=dict(
+         type='SwinTransformer',
+         embed_dim=96,
+         depths=[2, 2, 6, 2],
+         num_heads=[3, 6, 12, 24],
+         window_size=7,
+         mlp_ratio=4.,
+         qkv_bias=True,
+         qk_scale=None,
+         drop_rate=0.,
+         attn_drop_rate=0.,
+         drop_path_rate=0.2,
+         ape=False,
+         patch_norm=True,
+         out_indices=(0, 1, 2, 3),
+         use_checkpoint=False),
+     neck=dict(
+         type='FPN',
+         in_channels=[96, 192, 384, 768],
+         out_channels=256,
+         num_outs=5),
+     rpn_head=dict(
+         type='RPNHead',
+         in_channels=256,
+         feat_channels=256,
+         anchor_generator=dict(
+             type='AnchorGenerator',
+             scales=[8],
+             ratios=[0.5, 1.0, 2.0],
+             strides=[4, 8, 16, 32, 64]),
+         bbox_coder=dict(
+             type='DeltaXYWHBBoxCoder',
+             target_means=[.0, .0, .0, .0],
+             target_stds=[1.0, 1.0, 1.0, 1.0]),
+         loss_cls=dict(
+             type='CrossEntropyLoss', use_sigmoid=True, loss_weight=1.0),
+         loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+     roi_head=dict(
+         type='StandardRoIHead',
+         bbox_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=7, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         bbox_head=dict(
+             type='Shared2FCBBoxHead',
+             in_channels=256,
+             fc_out_channels=1024,
+             roi_feat_size=7,
+             num_classes=80,
+             bbox_coder=dict(
+                 type='DeltaXYWHBBoxCoder',
+                 target_means=[0., 0., 0., 0.],
+                 target_stds=[0.1, 0.1, 0.2, 0.2]),
+             reg_class_agnostic=False,
+             loss_cls=dict(
+                 type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
+             loss_bbox=dict(type='L1Loss', loss_weight=1.0)),
+         mask_roi_extractor=dict(
+             type='SingleRoIExtractor',
+             roi_layer=dict(type='RoIAlign', output_size=14, sampling_ratio=0),
+             out_channels=256,
+             featmap_strides=[4, 8, 16, 32]),
+         mask_head=dict(
+             type='FCNOccMaskHead',
+             num_convs=4,
+             in_channels=256,
+             conv_out_channels=256,
+             num_classes=80,
+             loss_mask=dict(
+                 type='CrossEntropyLoss', use_mask=True, loss_weight=1.0))),
+     # model training and testing settings
+     train_cfg=dict(
+         rpn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.7,
+                 neg_iou_thr=0.3,
+                 min_pos_iou=0.3,
+                 match_low_quality=True,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=256,
+                 pos_fraction=0.5,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=False),
+             allowed_border=-1,
+             pos_weight=-1,
+             debug=False),
+         rpn_proposal=dict(
+             nms_pre=2000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             assigner=dict(
+                 type='MaxIoUAssigner',
+                 pos_iou_thr=0.5,
+                 neg_iou_thr=0.5,
+                 min_pos_iou=0.5,
+                 match_low_quality=True,
+                 ignore_iof_thr=-1),
+             sampler=dict(
+                 type='RandomSampler',
+                 num=512,
+                 pos_fraction=0.25,
+                 neg_pos_ub=-1,
+                 add_gt_as_proposals=True),
+             mask_size=28,
+             pos_weight=-1,
+             debug=False)),
+     test_cfg=dict(
+         rpn=dict(
+             nms_pre=1000,
+             max_per_img=1000,
+             nms=dict(type='nms', iou_threshold=0.7),
+             min_bbox_size=0),
+         rcnn=dict(
+             score_thr=0.05,
+             nms=dict(type='nms', iou_threshold=0.5),
+             max_per_img=100,
+             mask_thr_binary=0.5)))
configs/_base_/schedules/schedule_1x.py ADDED
@@ -0,0 +1,11 @@
+ # optimizer
+ optimizer = dict(type='SGD', lr=0.02, momentum=0.9, weight_decay=0.0001)
+ optimizer_config = dict(grad_clip=None)
+ # learning policy
+ lr_config = dict(
+     policy='step',
+     warmup='linear',
+     warmup_iters=500,
+     warmup_ratio=0.001,
+     step=[8, 11])
+ runner = dict(type='EpochBasedRunner', max_epochs=12)
configs/walt/walt_people.py ADDED
@@ -0,0 +1,80 @@
+ _base_ = [
+     '../_base_/models/occ_mask_rcnn_swin_fpn.py',
+     '../_base_/datasets/walt_people.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+
+ model = dict(
+     backbone=dict(
+         embed_dim=96,
+         depths=[2, 2, 6, 2],
+         num_heads=[3, 6, 12, 24],
+         window_size=7,
+         ape=False,
+         drop_path_rate=0.1,
+         patch_norm=True,
+         use_checkpoint=False
+     ),
+     neck=dict(in_channels=[96, 192, 384, 768]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                  'relative_position_bias_table': dict(decay_mult=0.),
+                                                  'norm': dict(decay_mult=0.)}))
+ lr_config = dict(step=[8, 11])
+ runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+
+ # do not use mmdet version fp16
+ fp16 = None
+ optimizer_config = dict(
+     type="DistOptimizerHook",
+     update_interval=1,
+     grad_clip=None,
+     coalesce=True,
+     bucket_size_mb=-1,
+     use_fp16=True,
+ )
configs/walt/walt_vehicle.py ADDED
@@ -0,0 +1,80 @@
+ _base_ = [
+     '../_base_/models/occ_mask_rcnn_swin_fpn.py',
+     '../_base_/datasets/walt_vehicle.py',
+     '../_base_/schedules/schedule_1x.py', '../_base_/default_runtime.py'
+ ]
+
+ model = dict(
+     backbone=dict(
+         embed_dim=96,
+         depths=[2, 2, 6, 2],
+         num_heads=[3, 6, 12, 24],
+         window_size=7,
+         ape=False,
+         drop_path_rate=0.1,
+         patch_norm=True,
+         use_checkpoint=False
+     ),
+     neck=dict(in_channels=[96, 192, 384, 768]))
+
+ img_norm_cfg = dict(
+     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
+
+ # augmentation strategy originates from DETR / Sparse RCNN
+ train_pipeline = [
+     dict(type='LoadImageFromFile'),
+     dict(type='LoadAnnotations', with_bbox=True, with_mask=True),
+     dict(type='RandomFlip', flip_ratio=0.5),
+     dict(type='AutoAugment',
+          policies=[
+              [
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333), (576, 1333),
+                                  (608, 1333), (640, 1333), (672, 1333), (704, 1333),
+                                  (736, 1333), (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True)
+              ],
+              [
+                  dict(type='Resize',
+                       img_scale=[(400, 1333), (500, 1333), (600, 1333)],
+                       multiscale_mode='value',
+                       keep_ratio=True),
+                  dict(type='RandomCrop',
+                       crop_type='absolute_range',
+                       crop_size=(384, 600),
+                       allow_negative_crop=True),
+                  dict(type='Resize',
+                       img_scale=[(480, 1333), (512, 1333), (544, 1333),
+                                  (576, 1333), (608, 1333), (640, 1333),
+                                  (672, 1333), (704, 1333), (736, 1333),
+                                  (768, 1333), (800, 1333)],
+                       multiscale_mode='value',
+                       override=True,
+                       keep_ratio=True)
+              ]
+          ]),
+     dict(type='Normalize', **img_norm_cfg),
+     dict(type='Pad', size_divisor=32),
+     dict(type='DefaultFormatBundle'),
+     dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels', 'gt_masks']),
+ ]
+ data = dict(train=dict(pipeline=train_pipeline))
+
+ optimizer = dict(_delete_=True, type='AdamW', lr=0.0001, betas=(0.9, 0.999), weight_decay=0.05,
+                  paramwise_cfg=dict(custom_keys={'absolute_pos_embed': dict(decay_mult=0.),
+                                                  'relative_position_bias_table': dict(decay_mult=0.),
+                                                  'norm': dict(decay_mult=0.)}))
+ lr_config = dict(step=[8, 11])
+ runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
+
+ # do not use mmdet version fp16
+ fp16 = None
+ optimizer_config = dict(
+     type="DistOptimizerHook",
+     update_interval=1,
+     grad_clip=None,
+     coalesce=True,
+     bucket_size_mb=-1,
+     use_fp16=True,
+ )
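Aside: both top-level configs merge the _base_ fragments above and override the optimizer and runner. A sketch of inspecting the merged result, assuming mmcv is installed and the working directory is the repository root:

    from mmcv import Config

    cfg = Config.fromfile('configs/walt/walt_vehicle.py')
    print(cfg.model.roi_head.mask_head.type)  # 'FCNOccMaskHead', from occ_mask_rcnn_swin_fpn.py
    print(cfg.optimizer.type)                 # 'AdamW', replacing the SGD from schedule_1x.py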
cwalt/CWALT.py ADDED
@@ -0,0 +1,161 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Tue Oct 19 19:14:47 2021
+
+ @author: dinesh
+ """
+ import glob
+ from .utils import bb_intersection_over_union_unoccluded
+ import numpy as np
+ from PIL import Image
+ import datetime
+ import cv2
+ import os
+ from tqdm import tqdm
+
+
+ def get_image(time, folder):
+     image = None  # stays None if no week folder contains the frame
+     for week_loop in range(5):
+         try:
+             image = np.array(Image.open(folder + '/week' + str(week_loop) + '/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg'))
+             break
+         except:
+             continue
+     if image is None:
+         print('file not found')
+     return image
+
+ def get_mask(segm, image):
+     poly = np.array(segm).reshape((int(len(segm)/2), 2))
+     mask = image.copy()*0
+     cv2.fillConvexPoly(mask, poly, (255, 255, 255))
+     return mask
+
+ def get_unoccluded(indices, tracks_all):
+     unoccluded_indexes = []
+     unoccluded_index_all = []
+     while 1:
+         unoccluded_clusters = []
+         len_unocc = len(unoccluded_indexes)
+         for ind in indices:
+             if ind in unoccluded_indexes:
+                 continue
+             occ = False
+             for ind_compare in indices:
+                 if ind_compare in unoccluded_indexes:
+                     continue
+                 if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare]) > 0.01 and ind_compare != ind:
+                     occ = True
+             if occ == False:
+                 unoccluded_indexes.extend([ind])
+                 unoccluded_clusters.extend([ind])
+         if len(unoccluded_indexes) == len_unocc and len_unocc != 0:
+             for ind in indices:
+                 if ind not in unoccluded_indexes:
+                     unoccluded_indexes.extend([ind])
+                     unoccluded_clusters.extend([ind])
+
+         unoccluded_index_all.append(unoccluded_clusters)
+         if len(unoccluded_indexes) > len(indices) - 5:
+             break
+     return unoccluded_index_all
+
+ def primes(n):  # simple sieve of multiples
+     odds = range(3, n+1, 2)
+     sieve = set(sum([list(range(q*q, n+1, q+q)) for q in odds], []))
+     return [2] + [p for p in odds if p not in sieve]
+
+ def save_image(image_read, save_path, data, path):
+     tracks = data['tracks_all_unoccluded']
+     segmentations = data['segmentation_all_unoccluded']
+     timestamps = data['timestamps_final_unoccluded']
+
+     image = image_read.copy()
+     indices = np.random.randint(len(tracks), size=30)
+     prime_numbers = primes(1000)
+     unoccluded_index_all = get_unoccluded(indices, tracks)
+
+     mask_stacked = image*0
+     mask_stacked_all = []
+     count = 0
+     time = datetime.datetime.now()
+
+     for l in indices:
+         try:
+             image_crop = get_image(timestamps[l], path)
+         except:
+             continue
+         try:
+             bb_left, bb_top, bb_width, bb_height, confidence = tracks[l]
+         except:
+             bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[l]
+         mask = get_mask(segmentations[l], image)
+
+         image[mask > 0] = image_crop[mask > 0]
+         mask[mask > 0] = 1
+         for count, mask_inc in enumerate(mask_stacked_all):
+             mask_stacked_all[count][cv2.bitwise_and(mask, mask_inc) > 0] = 2
+         mask_stacked_all.append(mask)
+         mask_stacked += mask
+         count = count + 1
+
+     cv2.imwrite(save_path + '/images/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg', image[:, :, ::-1])
+     cv2.imwrite(save_path + '/Segmentation/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg', mask_stacked[:, :, ::-1]*30)
+     np.savez_compressed(save_path + '/Segmentation/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0], mask=mask_stacked_all)
+
+ def CWALT_Generation(camera_name):
+     save_path_train = 'data/cwalt_train'
+     save_path_test = 'data/cwalt_test'
+
+     json_file_path = 'data/{}/{}.json'.format(camera_name, camera_name)
+     path = 'data/' + camera_name
+
+     data = np.load(json_file_path + '.npz', allow_pickle=True)
+
+     # split data into train and test
+
+     data_train = dict()
+     data_test = dict()
+
+     split_index = int(len(data['timestamps_final_unoccluded'])*0.8)
+
+     data_train['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][0:split_index]
+     data_train['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][0:split_index]
+     data_train['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][0:split_index]
+
+     data_test['tracks_all_unoccluded'] = data['tracks_all_unoccluded'][split_index:]
+     data_test['segmentation_all_unoccluded'] = data['segmentation_all_unoccluded'][split_index:]
+     data_test['timestamps_final_unoccluded'] = data['timestamps_final_unoccluded'][split_index:]
+
+     image_read = np.array(Image.open(path + '/T18-median_image.jpg'))
+     image_read = cv2.resize(image_read, (int(image_read.shape[1]/2), int(image_read.shape[0]/2)))
+
+     try:
+         os.mkdir(save_path_train)
+     except:
+         print(save_path_train)
+
+     try:
+         os.mkdir(save_path_train + '/images')
+         os.mkdir(save_path_train + '/Segmentation')
+     except:
+         print(save_path_train + '/images')
+
+     try:
+         os.mkdir(save_path_test)
+     except:
+         print(save_path_test)
+
+     try:
+         os.mkdir(save_path_test + '/images')
+         os.mkdir(save_path_test + '/Segmentation')
+     except:
+         print(save_path_test + '/images')
+
+     for loop in tqdm(range(3000), desc="Generating training CWALT Images "):
+         save_image(image_read, save_path_train, data_train, path)
+
+     for loop in tqdm(range(300), desc="Generating testing CWALT Images "):
+         save_image(image_read, save_path_test, data_test, path)
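Aside: a sketch of invoking the generator directly; 'cam2' is a placeholder camera name, and data/cam2/cam2.json.npz is assumed to already exist from the clip-extraction step below:

    from cwalt.CWALT import CWALT_Generation

    # Writes composited frames and masks under data/cwalt_train/{images,Segmentation}
    # and data/cwalt_test/{images,Segmentation}, per the paths in CWALT_Generation.
    CWALT_Generation('cam2')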
cwalt/Clip_WALT_Generate.py ADDED
@@ -0,0 +1,284 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Fri May 20 15:15:11 2022
+
+ @author: dinesh
+ """
+
+ from collections import OrderedDict
+ from matplotlib import pyplot as plt
+ from .utils import *
+ import scipy.interpolate
+
+ from scipy import interpolate
+ from .clustering_utils import *
+ import glob
+ import cv2
+ from PIL import Image
+
+
+ import json
+
+ import numpy as np
+ from tqdm import tqdm
+
+
+ def ignore_indexes(tracks_all, labels_all):
+     # get repeating bounding boxes
+     get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
+     ignore_ind = []
+     for index, track in enumerate(tracks_all):
+         print('in ignore', index, len(tracks_all))
+         if index in ignore_ind:
+             continue
+
+         if labels_all[index] < 1 or labels_all[index] > 3:
+             ignore_ind.extend([index])
+
+         ind = get_indexes(track, tracks_all)
+         if len(ind) > 30:
+             ignore_ind.extend(ind)
+
+     return ignore_ind
+
+ def repeated_indexes_old(tracks_all, ignore_ind, unoccluded_indexes=None):
+     # get repeating bounding boxes
+     get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
+     repeat_ind = []
+     repeat_inds = []
+     if unoccluded_indexes == None:
+         for index, track in enumerate(tracks_all):
+             if index in repeat_ind or index in ignore_ind:
+                 continue
+             ind = get_indexes(track, tracks_all)
+             if len(ind) > 20:
+                 repeat_ind.extend(ind)
+                 repeat_inds.append([ind, track])
+     else:
+         for index in unoccluded_indexes:
+             if index in repeat_ind or index in ignore_ind:
+                 continue
+             ind = get_indexes(tracks_all[index], tracks_all)
+             if len(ind) > 3:
+                 repeat_ind.extend(ind)
+                 repeat_inds.append([ind, tracks_all[index]])
+     return repeat_inds
+
+ def get_unoccluded_instances(timestamps_final, tracks_all, ignore_ind=[], threshold=0.01):
+     get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
+     unoccluded_indexes = []
+     time_checked = []
+     stationary_obj = []
+     count = 0
+
+     for time in tqdm(np.unique(timestamps_final), desc="Detecting unoccluded objects in image "):
+         count += 1
+         if [time.year, time.month, time.day, time.hour, time.minute, time.second, time.microsecond] in time_checked:
+             analyze_bb = []
+             for ind in unoccluded_indexes_time:
+                 for ind_compare in same_time_instances:
+                     iou = bb_intersection_over_union(tracks_all[ind], tracks_all[ind_compare])
+                     if iou < 0.5 and iou > 0:
+                         analyze_bb.extend([ind_compare])
+                     if iou > 0.99:
+                         stationary_obj.extend([str(ind_compare) + '+' + str(ind)])
+
+             for ind in analyze_bb:
+                 occ = False
+                 for ind_compare in same_time_instances:
+                     if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind:
+                         occ = True
+                         break
+                 if occ == False:
+                     unoccluded_indexes.extend([ind])
+             continue
+
+         same_time_instances = get_indexes(time, timestamps_final)
+         unoccluded_indexes_time = []
+
+         for ind in same_time_instances:
+             if tracks_all[ind][4] < 0.9 or ind in ignore_ind:  # or ind != 1859:
+                 continue
+             occ = False
+             for ind_compare in same_time_instances:
+                 if bb_intersection_over_union_unoccluded(tracks_all[ind], tracks_all[ind_compare], threshold=threshold) > threshold and ind_compare != ind and tracks_all[ind_compare][4] < 0.5:
+                     occ = True
+                     break
+             if occ == False:
+                 unoccluded_indexes.extend([ind])
+                 unoccluded_indexes_time.extend([ind])
+         time_checked.append([time.year, time.month, time.day, time.hour, time.minute, time.second, time.microsecond])
+     return unoccluded_indexes, stationary_obj
+
+ def visualize_unoccluded_detection(timestamps_final, tracks_all, segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name, ignore_ind=[]):
+     tracks_final = []
+     tracks_final.append([])
+     try:
+         os.mkdir(cwalt_data_path + '/' + camera_name + '_unoccluded_car_detection/')
+     except:
+         print('Unoccluded debugging exists')
+
+     for time in tqdm(np.unique(timestamps_final), desc="Visualizing unoccluded objects in image "):
+         get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if x == y]
+         ind = get_indexes(time, timestamps_final)
+         image_unocc = False
+         for index in ind:
+             if index not in unoccluded_indexes:
+                 continue
+             else:
+                 image_unocc = True
+                 break
+         if image_unocc == False:
+             continue
+
+         for week_loop in range(5):
+             try:
+                 image = np.array(Image.open(cwalt_data_path + '/week' + str(week_loop) + '/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg'))
+                 break
+             except:
+                 continue
+
+         try:
+             mask = image*0
+         except:
+             print('image not found for ' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg')
+             continue
+         image_original = image.copy()
+
+         for index in ind:
+             track = tracks_all[index]
+
+             if index in ignore_ind:
+                 continue
+             if index not in unoccluded_indexes:
+                 continue
+             try:
+                 bb_left, bb_top, bb_width, bb_height, confidence, id = track
+             except:
+                 bb_left, bb_top, bb_width, bb_height, confidence = track
+
+             if confidence > 0.6:
+                 mask = poly_seg(image, segmentation_all[index])
+         cv2.imwrite(cwalt_data_path + '/' + camera_name + '_unoccluded_car_detection/' + str(index) + '.png', mask[:, :, ::-1])
+
+ def repeated_indexes(tracks_all, ignore_ind, repeat_count=10, unoccluded_indexes=None):
+     get_indexes = lambda x, xs: [i for (y, i) in zip(xs, range(len(xs))) if bb_intersection_over_union(x, y) > 0.8 and i not in ignore_ind]
+     repeat_ind = []
+     repeat_inds = []
+     if unoccluded_indexes == None:
+         for index, track in enumerate(tracks_all):
+             if index in repeat_ind or index in ignore_ind:
+                 continue
+
+             ind = get_indexes(track, tracks_all)
+             if len(ind) > repeat_count:
+                 repeat_ind.extend(ind)
+                 repeat_inds.append([ind, track])
+     else:
+         for index in unoccluded_indexes:
+             if index in repeat_ind or index in ignore_ind:
+                 continue
+             ind = get_indexes(tracks_all[index], tracks_all)
+             if len(ind) > repeat_count:
+                 repeat_ind.extend(ind)
+                 repeat_inds.append([ind, tracks_all[index]])
+
+     return repeat_inds
+
+ def poly_seg(image, segm):
+     poly = np.array(segm).reshape((int(len(segm)/2), 2))
+     overlay = image.copy()
+     alpha = 0.5
+     cv2.fillPoly(overlay, [poly], color=(255, 255, 0))
+     cv2.addWeighted(overlay, alpha, image, 1 - alpha, 0, image)
+     return image
+
+ def visualize_unoccuded_clusters(repeat_inds, tracks, segmentation_all, timestamps_final, cwalt_data_path):
+     for index_, repeat_ind in enumerate(repeat_inds):
+         image = np.array(Image.open(cwalt_data_path + '/' + 'T18-median_image.jpg'))
+         try:
+             os.mkdir(cwalt_data_path + '/Cwalt_database/')
+         except:
+             print('folder exists')
+         try:
+             os.mkdir(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/')
+         except:
+             print(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/')
+
+         for i in repeat_ind[0]:
+             try:
+                 bb_left, bb_top, bb_width, bb_height, confidence = tracks[i]  # bbox
+             except:
+                 bb_left, bb_top, bb_width, bb_height, confidence, track_id = tracks[i]  # bbox
+
+             cv2.rectangle(image, (int(bb_left), int(bb_top)), (int(bb_left + bb_width), int(bb_top + bb_height)), (0, 0, 255), 2)
+             time = timestamps_final[i]
+             for week_loop in range(5):
+                 try:
+                     image1 = np.array(Image.open(cwalt_data_path + '/week' + str(week_loop) + '/' + str(time).replace(' ', 'T').replace(':', '-').split('+')[0] + '.jpg'))
+                     break
+                 except:
+                     continue
+
+             crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
+             cv2.imwrite(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/o_' + str(i) + '.jpg', crop[:, :, ::-1])
+             image1 = poly_seg(image1, segmentation_all[i])
+             crop = image1[int(bb_top): int(bb_top + bb_height), int(bb_left):int(bb_left + bb_width)]
+             cv2.imwrite(cwalt_data_path + '/Cwalt_database/' + str(index_) + '/' + str(i) + '.jpg', crop[:, :, ::-1])
+         if index_ > 100:
+             break
+
+         cv2.imwrite(cwalt_data_path + '/Cwalt_database/' + str(index_) + '.jpg', image[:, :, ::-1])
+
+ def Get_unoccluded_objects(camera_name, debug=False, scale=True):
+     cwalt_data_path = 'data/' + camera_name
+     data_folder = cwalt_data_path
+     json_file_path = cwalt_data_path + '/' + camera_name + '.json'
+
+     with open(json_file_path, 'r') as j:
+         annotations = json.loads(j.read())
+
+     tracks_all = [parse_bbox(anno['bbox']) for anno in annotations]
+     segmentation_all = [parse_bbox(anno['segmentation']) for anno in annotations]
+     labels_all = [anno['label_id'] for anno in annotations]
+     timestamps_final = [parse(anno['time']) for anno in annotations]
+
+     if scale == True:
+         scale_factor = 2
+         tracks_all_numpy = np.array(tracks_all)
+         tracks_all_numpy[:, :4] = np.array(tracks_all)[:, :4] / scale_factor
+         tracks_all = tracks_all_numpy.tolist()
+
+         segmentation_all_scaled = []
+         for list_loop in segmentation_all:
+             segmentation_all_scaled.append((np.floor_divide(np.array(list_loop), scale_factor)).tolist())
+         segmentation_all = segmentation_all_scaled
+
+     if debug == True:
+         timestamps_final = timestamps_final[:1000]
+         labels_all = labels_all[:1000]
+         segmentation_all = segmentation_all[:1000]
+         tracks_all = tracks_all[:1000]
+
+     unoccluded_indexes, stationary = get_unoccluded_instances(timestamps_final, tracks_all, threshold=0.05)
+     if debug == True:
+         visualize_unoccluded_detection(timestamps_final, tracks_all, segmentation_all, unoccluded_indexes, cwalt_data_path, camera_name)
+
+     tracks_all_unoccluded = [tracks_all[i] for i in unoccluded_indexes]
+     segmentation_all_unoccluded = [segmentation_all[i] for i in unoccluded_indexes]
+     labels_all_unoccluded = [labels_all[i] for i in unoccluded_indexes]
+     timestamps_final_unoccluded = [timestamps_final[i] for i in unoccluded_indexes]
+     np.savez(json_file_path, tracks_all_unoccluded=tracks_all_unoccluded, segmentation_all_unoccluded=segmentation_all_unoccluded, labels_all_unoccluded=labels_all_unoccluded, timestamps_final_unoccluded=timestamps_final_unoccluded)
+
+     if debug == True:
+         repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded, [], repeat_count=1)
+         visualize_unoccuded_clusters(repeat_inds_clusters, tracks_all_unoccluded, segmentation_all_unoccluded, timestamps_final_unoccluded, cwalt_data_path)
+     else:
+         repeat_inds_clusters = repeated_indexes(tracks_all_unoccluded, [], repeat_count=10)
+
+     np.savez(json_file_path + '_clubbed', repeat_inds=repeat_inds_clusters)
+     np.savez(json_file_path + '_stationary', stationary=stationary)
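Aside: this module filters the raw per-camera detection dump down to unoccluded instances and saves them as an .npz next to the JSON; cwalt_generate.py (file 21 in the list) presumably chains the two stages. A sketch of that order, with 'cam2' again as a placeholder:

    from cwalt.Clip_WALT_Generate import Get_unoccluded_objects
    from cwalt.CWALT import CWALT_Generation

    Get_unoccluded_objects('cam2', debug=False, scale=True)  # data/cam2/cam2.json -> cam2.json.npz
    CWALT_Generation('cam2')                                 # composites CWALT images from the .npz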
cwalt/Download_Detections.py ADDED
@@ -0,0 +1,28 @@
+ import json
+ from psycopg2.extras import RealDictCursor
+ #import cv2
+ import psycopg2
+ import cv2
+
+
+ CONNECTION = "postgres://postgres:"
+
+ conn = psycopg2.connect(CONNECTION)
+ cursor = conn.cursor(cursor_factory=RealDictCursor)
+
+
+ def get_sample():
+     camera_name, camera_id = 'cam2', 4
+
+     print('Executing SQL command')
+
+     cursor.execute("SELECT * FROM annotations WHERE camera_id = {} and time >='2021-05-01 00:00:00' and time <='2021-05-07 23:59:50' and label_id in (1,2)".format(camera_id))
+
+     print('Dumping to json')
+     annotations = json.dumps(cursor.fetchall(), indent=2, default=str)
+     wjdata = json.loads(annotations)
+     with open('{}_{}_test.json'.format(camera_name, camera_id), 'w') as f:
+         json.dump(wjdata, f)
+     print('Done dumping to json')
+
+ get_sample()
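Aside: the query above splices camera_id in with str.format; psycopg2's own parameter binding does the quoting safely. An equivalent call (the time window is copied from the literal above):

    cursor.execute(
        "SELECT * FROM annotations "
        "WHERE camera_id = %s AND time >= %s AND time <= %s AND label_id IN (1, 2)",
        (camera_id, '2021-05-01 00:00:00', '2021-05-07 23:59:50'),
    )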
cwalt/clustering_utils.py ADDED
@@ -0,0 +1,132 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Fri May 20 15:18:20 2022
+
+ @author: dinesh
+ """
+
+ # 0 - Import related libraries
+
+ import urllib
+ import zipfile
+ import os
+ import scipy.io
+ import math
+ import numpy as np
+ import matplotlib.pyplot as plt
+ import seaborn as sns
+
+ from scipy.spatial.distance import directed_hausdorff
+ from sklearn.cluster import DBSCAN
+ from sklearn.metrics.pairwise import pairwise_distances
+ import scipy.spatial.distance
+
+ from .kmedoid import kMedoids  # kMedoids code is adapted from https://github.com/letiantian/kmedoids
+
+ # Some visualization stuff, not so important
+ # sns.set()
+ plt.rcParams['figure.figsize'] = (12, 12)
+
+ # Utility Functions
+
+ color_lst = plt.rcParams['axes.prop_cycle'].by_key()['color']
+ color_lst.extend(['firebrick', 'olive', 'indigo', 'khaki', 'teal', 'saddlebrown',
+                   'skyblue', 'coral', 'darkorange', 'lime', 'darkorchid', 'dimgray'])
+
+
+ def plot_cluster(image, traj_lst, cluster_lst):
+     '''
+     Plots given trajectories with a color that is specific for every trajectory's own cluster index.
+     Outlier trajectories which are specified with -1 in `cluster_lst` are plotted dashed with black color.
+     '''
+     cluster_count = np.max(cluster_lst) + 1
+
+     for traj, cluster in zip(traj_lst, cluster_lst):
+         # if cluster == -1:
+         #     # Means it is a noisy trajectory, paint it black
+         #     plt.plot(traj[:, 0], traj[:, 1], c='k', linestyle='dashed')
+         # else:
+         plt.plot(traj[:, 0], traj[:, 1], c=color_lst[cluster % len(color_lst)])
+
+     plt.imshow(image)
+     # plt.show()
+     plt.axis('off')
+     plt.savefig('trajectory.png', bbox_inches='tight')
+     plt.show()
+
+
+ # 3 - Distance matrix
+
+ def hausdorff(u, v):
+     d = max(directed_hausdorff(u, v)[0], directed_hausdorff(v, u)[0])
+     return d
+
+
+ def build_distance_matrix(traj_lst):
+     # 2 - Trajectory segmentation
+
+     print('Running trajectory segmentation...')
+     degree_threshold = 5
+
+     for traj_index, traj in enumerate(traj_lst):
+         hold_index_lst = []
+         previous_azimuth = 1000
+
+         for point_index, point in enumerate(traj[:-1]):
+             next_point = traj[point_index + 1]
+             diff_vector = next_point - point
+             azimuth = (math.degrees(math.atan2(*diff_vector)) + 360) % 360
+
+             if abs(azimuth - previous_azimuth) > degree_threshold:
+                 hold_index_lst.append(point_index)
+                 previous_azimuth = azimuth
+         hold_index_lst.append(traj.shape[0] - 1)  # Last point of trajectory is always added
+
+         traj_lst[traj_index] = traj[hold_index_lst, :]
+
+     print('Building distance matrix...')
+     traj_count = len(traj_lst)
+     D = np.zeros((traj_count, traj_count))
+
+     # This may take a while
+     for i in range(traj_count):
+         if i % 20 == 0:
+             print(i)
+         for j in range(i + 1, traj_count):
+             distance = hausdorff(traj_lst[i], traj_lst[j])
+             D[i, j] = distance
+             D[j, i] = distance
+
+     return D
+
+
+ def run_kmedoids(image, traj_lst, D):
+     # 4 - Different clustering methods
+
+     # 4.1 - kmedoids
+
+     traj_count = len(traj_lst)
+
+     k = 3  # The number of clusters
+     medoid_center_lst, cluster2index_lst = kMedoids(D, k)
+
+     cluster_lst = np.empty((traj_count,), dtype=int)
+
+     for cluster in cluster2index_lst:
+         cluster_lst[cluster2index_lst[cluster]] = cluster
+
+     plot_cluster(image, traj_lst, cluster_lst)
+
+
+ def run_dbscan(image, traj_lst, D):
+     mdl = DBSCAN(eps=400, min_samples=10)
+     cluster_lst = mdl.fit_predict(D)
+
+     plot_cluster(image, traj_lst, cluster_lst)
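Aside: a sketch of how these helpers chain together; traj_lst here is a hypothetical list of N x 2 arrays of track centroids (the real ones come from the tracked detections, not from this file):

    import numpy as np
    from cwalt.clustering_utils import build_distance_matrix, run_kmedoids, run_dbscan

    traj_lst = [np.cumsum(np.random.rand(50, 2) * 5, axis=0) for _ in range(30)]  # toy tracks
    image = np.zeros((800, 1200, 3), dtype=np.uint8)                              # stand-in background

    D = build_distance_matrix(traj_lst)  # pairwise symmetric Hausdorff distances
    run_kmedoids(image, traj_lst, D)     # k = 3 medoid clustering, then plots
    run_dbscan(image, traj_lst, D)       # density-based alternative on the same matrix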
cwalt/kmedoid.py ADDED
@@ -0,0 +1,55 @@
+ #!/usr/bin/env python3
+ # -*- coding: utf-8 -*-
+ """
+ Created on Fri May 20 15:18:56 2022
+
+ @author: dinesh
+ """
+
+ import numpy as np
+ import math
+
+ def kMedoids(D, k, tmax=100):
+     # determine dimensions of distance matrix D
+     m, n = D.shape
+
+     np.fill_diagonal(D, math.inf)
+
+     if k > n:
+         raise Exception('too many medoids')
+     # randomly initialize an array of k medoid indices
+     M = np.arange(n)
+     np.random.shuffle(M)
+     M = np.sort(M[:k])
+
+     # create a copy of the array of medoid indices
+     Mnew = np.copy(M)
+
+     # initialize a dictionary to represent clusters
+     C = {}
+     for t in range(tmax):
+         # determine clusters, i.e. arrays of data indices
+         J = np.argmin(D[:, M], axis=1)
+
+         for kappa in range(k):
+             C[kappa] = np.where(J == kappa)[0]
+         # update cluster medoids
+         for kappa in range(k):
+             J = np.mean(D[np.ix_(C[kappa], C[kappa])], axis=1)
+             j = np.argmin(J)
+             Mnew[kappa] = C[kappa][j]
+         Mnew.sort()  # sort in place; a bare np.sort(Mnew) would discard the sorted result
+         # check for convergence
+         if np.array_equal(M, Mnew):
+             break
+         M = np.copy(Mnew)
+     else:
+         # final update of cluster memberships
+         J = np.argmin(D[:, M], axis=1)
+         for kappa in range(k):
+             C[kappa] = np.where(J == kappa)[0]
+
+     np.fill_diagonal(D, 0)
+
+     # return results
+     return M, C
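Aside: a quick sanity check of the medoid clustering on a toy distance matrix (not part of the commit):

    import numpy as np
    from cwalt.kmedoid import kMedoids

    # Two obvious groups: points 0-2 sit together, points 3-5 sit together.
    pts = np.array([[0.0], [0.1], [0.2], [5.0], [5.1], [5.2]])
    D = np.abs(pts - pts.T)  # pairwise 1-D distances, shape (6, 6)

    medoids, clusters = kMedoids(D, k=2)
    print(medoids)   # one index drawn from each group
    print(clusters)  # {0: array of member indices, 1: array of member indices}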
cwalt/utils.py ADDED
@@ -0,0 +1,168 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Fri May 20 15:16:56 2022
5
+
6
+ @author: dinesh
7
+ """
8
+
9
+ import json
10
+ import cv2
11
+ from PIL import Image
12
+ import numpy as np
13
+ from dateutil.parser import parse
14
+
15
+ def bb_intersection_over_union(box1, box2):
16
+ #print(box1, box2)
17
+ boxA = box1.copy()
18
+ boxB = box2.copy()
19
+ boxA[2] = boxA[0]+boxA[2]
20
+ boxA[3] = boxA[1]+boxA[3]
21
+ boxB[2] = boxB[0]+boxB[2]
22
+ boxB[3] = boxB[1]+boxB[3]
23
+ # determine the (x, y)-coordinates of the intersection rectangle
24
+ xA = max(boxA[0], boxB[0])
25
+ yA = max(boxA[1], boxB[1])
26
+ xB = min(boxA[2], boxB[2])
27
+ yB = min(boxA[3], boxB[3])
28
+
29
+ # compute the area of intersection rectangle
30
+ interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
31
+
32
+ if interArea == 0:
33
+ return 0
34
+ # compute the area of both the prediction and ground-truth
35
+ # rectangles
36
+ boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
37
+ boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
38
+
39
+ # compute the intersection over union by taking the intersection
40
+ # area and dividing it by the sum of prediction + ground-truth
41
+ # areas - the interesection area
42
+ iou = interArea / float(boxAArea + boxBArea - interArea)
43
+ return iou
44
+
45
+ def bb_intersection_over_union_unoccluded(box1, box2, threshold=0.01):
46
+ #print(box1, box2)
47
+ boxA = box1.copy()
48
+ boxB = box2.copy()
49
+ boxA[2] = boxA[0]+boxA[2]
50
+ boxA[3] = boxA[1]+boxA[3]
51
+ boxB[2] = boxB[0]+boxB[2]
52
+ boxB[3] = boxB[1]+boxB[3]
53
+ # determine the (x, y)-coordinates of the intersection rectangle
54
+ xA = max(boxA[0], boxB[0])
55
+ yA = max(boxA[1], boxB[1])
56
+ xB = min(boxA[2], boxB[2])
57
+ yB = min(boxA[3], boxB[3])
58
+
59
+ # compute the area of intersection rectangle
60
+ interArea = abs(max((xB - xA, 0)) * max((yB - yA), 0))
61
+
62
+ if interArea == 0:
63
+ return 0
64
+ # compute the area of both the prediction and ground-truth
65
+ # rectangles
66
+ boxAArea = abs((boxA[2] - boxA[0]) * (boxA[3] - boxA[1]))
67
+ boxBArea = abs((boxB[2] - boxB[0]) * (boxB[3] - boxB[1]))
68
+
69
+ # compute the intersection over union by taking the intersection
70
+ # area and dividing it by the sum of prediction + ground-truth
71
+ # areas - the interesection area
72
+ iou = interArea / float(boxAArea + boxBArea - interArea)
73
+
74
+ #print(iou)
75
+ # return the intersection over union value
76
+ occlusion = False
77
+ if iou > threshold and iou < 1:
78
+ #print(boxA[3], boxB[3], boxB[1])
79
+ if boxA[3] < boxB[3]:# and boxA[3] > boxB[1]:
80
+ if boxB[2] > boxA[0]:# and boxB[2] < boxA[2]:
81
+ #print('first', (boxB[2] - boxA[0])/(boxA[2] - boxA[0]))
82
+ if (min(boxB[2],boxA[2]) - boxA[0])/(boxA[2] - boxA[0]) > threshold:
83
+ occlusion = True
84
+
85
+ if boxB[0] < boxA[2]: # boxB[0] > boxA[0] and
86
+ #print('second', (boxA[2] - boxB[0])/(boxA[2] - boxA[0]))
87
+ if (boxA[2] - max(boxB[0],boxA[0]))/(boxA[2] - boxA[0]) > threshold:
88
+ occlusion = True
89
+ if not occlusion:
90
+ iou = 0 # overlapping, but not a plausible occlusion pair
91
+ # The heuristic above flags box2 as a potential occluder of box1 only
92
+ # when box2's bottom edge sits lower in the image than box1's and
93
+ # their horizontal extents overlap by more than `threshold` of box1's
94
+ # width; any other kind of overlap is reported as unoccluded (IoU 0).
95
+ return iou
96
+ def draw_tracks(image, tracks):
97
+ """
98
+ Draw on input image.
99
+
100
+ Args:
101
+ image (numpy.ndarray): image
102
+ tracks (list): list of tracks to be drawn on the image.
103
+
104
+ Returns:
105
+ numpy.ndarray: image with the track-ids drawn on it.
106
+ """
107
+
108
+ for trk in tracks:
109
+
110
+ trk_id = trk[1]
111
+ xmin = trk[2]
112
+ ymin = trk[3]
113
+ width = trk[4]
114
+ height = trk[5]
115
+
116
+ xcentroid, ycentroid = int(xmin + 0.5*width), int(ymin + 0.5*height)
117
+
118
+ text = "ID {}".format(trk_id)
119
+
120
+ cv2.putText(image, text, (xcentroid - 10, ycentroid - 10), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)
121
+ cv2.circle(image, (xcentroid, ycentroid), 4, (0, 255, 0), -1)
122
+
123
+ return image
124
+
125
+
126
+ def draw_bboxes(image, tracks):
127
+ """
128
+ Draw the bounding boxes about detected objects in the image.
129
+
130
+ Args:
131
+ image (numpy.ndarray): Image or video frame.
132
+ tracks (list): list of tracks to be drawn; each entry carries its
133
+ box as (xmin, ymin, width, height) pixel coordinates in
134
+ elements 2-5.
135
+
136
+ Returns:
137
+ numpy.ndarray: image with the bounding boxes drawn on it.
138
+ """
139
+
140
+ for trk in tracks:
141
+ xmin = int(trk[2])
142
+ ymin = int(trk[3])
143
+ width = int(trk[4])
144
+ height = int(trk[5])
145
+ clr = (np.random.randint(0, 255), np.random.randint(0, 255), np.random.randint(0, 255))
146
+ cv2.rectangle(image, (xmin, ymin), (xmin + width, ymin + height), clr, 2)
147
+
148
+ return image
149
+
150
+
151
+ def num(v):
152
+ number_as_float = float(v)
153
+ number_as_int = int(number_as_float)
154
+ return number_as_int if number_as_float == number_as_int else number_as_float
155
+
156
+
157
+ def parse_bbox(bbox_str):
158
+ bbox_list = bbox_str.strip('{').strip('}').split(',')
159
+ bbox_list = [num(elem) for elem in bbox_list]
160
+ return bbox_list
161
+
162
+ def parse_seg(bbox_str):
163
+ bbox_list = bbox_str.strip('{').strip('}').split(',')
164
+ bbox_list = [num(elem) for elem in bbox_list]
165
+ ret = bbox_list # []
166
+ # for i in range(0, len(bbox_list) - 1, 2):
167
+ # ret.append((bbox_list[i], bbox_list[i + 1]))
168
+ return ret
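
Both IoU helpers above take boxes as (xmin, ymin, width, height) and convert them to corner form internally; the "unoccluded" variant additionally zeroes the score unless the second box can plausibly hide the first. A quick sanity check with made-up boxes:

    from cwalt.utils import bb_intersection_over_union, bb_intersection_over_union_unoccluded

    box_a = [0, 0, 10, 10]   # corners (0, 0)-(10, 10)
    box_b = [5, 5, 10, 10]   # corners (5, 5)-(15, 15); overlaps box_a in a 5x5 patch

    # intersection 25, union 100 + 100 - 25 = 175, so IoU ~= 0.143
    print(bb_intersection_over_union(box_a, box_b))

    # box_b reaches lower in the image than box_a and overlaps it horizontally,
    # so the occlusion test passes and the same IoU is returned; otherwise 0.
    print(bb_intersection_over_union_unoccluded(box_a, box_b, threshold=0.01))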
cwalt_generate.py ADDED
@@ -0,0 +1,14 @@
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ Created on Sat Jun 4 16:55:58 2022
5
+
6
+ @author: dinesh
7
+ """
8
+ from cwalt.CWALT import CWALT_Generation
9
+ from cwalt.Clip_WALT_Generate import Get_unoccluded_objects
10
+
11
+ if __name__ == '__main__':
12
+ camera_name = 'cam2'
13
+ Get_unoccluded_objects(camera_name)
14
+ CWALT_Generation(camera_name)
docker/Dockerfile ADDED
@@ -0,0 +1,52 @@
1
+ ARG PYTORCH="1.9.0"
2
+ ARG CUDA="11.1"
3
+ ARG CUDNN="8"
4
+
5
+ FROM pytorch/pytorch:${PYTORCH}-cuda${CUDA}-cudnn${CUDNN}-devel
6
+
7
+ ENV TORCH_CUDA_ARCH_LIST="6.0 6.1 7.0+PTX"
8
+ ENV TORCH_NVCC_FLAGS="-Xfatbin -compress-all"
9
+ ENV CMAKE_PREFIX_PATH="$(dirname $(which conda))/../"
10
+ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/cuda/repos/ubuntu1804/x86_64/3bf863cc.pub
11
+ RUN apt-key adv --fetch-keys https://developer.download.nvidia.com/compute/machine-learning/repos/ubuntu1804/x86_64/7fa2af80.pub
12
+ RUN apt-get update && apt-get install -y ffmpeg libsm6 libxext6 git ninja-build libglib2.0-0 libxrender-dev \
13
+ && apt-get clean \
14
+ && rm -rf /var/lib/apt/lists/*
15
+
16
+ # Install MMCV
17
+ #RUN pip install mmcv-full==1.3.8 -f https://download.openmmlab.com/mmcv/dist/cu101/torch1.6.0/index.html
18
+ # -f https://openmmlab.oss-accelerate.aliyuncs.com/mmcv/dist/index.html
19
+ RUN pip install mmcv-full==1.4.0 -f https://download.openmmlab.com/mmcv/dist/cu111/torch1.9.0/index.html
20
+ # Install MMDetection
21
+ RUN conda clean --all
22
+ RUN git clone https://github.com/open-mmlab/mmdetection.git /mmdetection
23
+ WORKDIR /mmdetection
24
+ ENV FORCE_CUDA="1"
25
+ RUN cd /mmdetection && git checkout 7bd39044f35aec4b90dd797b965777541a8678ff
26
+ RUN pip install -r requirements/build.txt
27
+ RUN pip install --no-cache-dir -e .
28
+ RUN apt-get update
29
+ RUN apt-get install -y vim
30
+ RUN pip uninstall -y pycocotools
31
+ RUN pip install mmpycocotools timm scikit-image imagesize
32
+
33
+
34
+ # make sure we don't overwrite some existing directory called "apex"
35
+ WORKDIR /tmp/unique_for_apex
36
+ # uninstall Apex if present, twice to make absolutely sure :)
37
+ RUN pip uninstall -y apex || :
38
+ RUN pip uninstall -y apex || :
39
+ # SHA is something the user can touch to force recreation of this Docker layer,
40
+ # and therefore force cloning of the latest version of Apex
41
+ RUN SHA=ToUcHMe git clone https://github.com/NVIDIA/apex.git
42
+ WORKDIR /tmp/unique_for_apex/apex
43
+ RUN pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" .
44
+ RUN pip install seaborn scikit-learn imantics gradio
45
+ WORKDIR /code
46
+ ENTRYPOINT ["python", "app.py"]
47
+
github_vis/cwalt.gif ADDED
github_vis/vis_cars.gif ADDED
github_vis/vis_people.gif ADDED
infer.py ADDED
@@ -0,0 +1,118 @@
1
+ from argparse import ArgumentParser
2
+
3
+ from mmdet.apis import inference_detector, init_detector, show_result_pyplot
4
+ from mmdet.core.mask.utils import encode_mask_results
5
+ import numpy as np
6
+ import mmcv
7
+ import torch
8
+ from imantics import Polygons, Mask
9
+ import json
10
+ import os
11
+ import cv2, glob
12
+
13
+ class detections():
14
+ def __init__(self, cfg_path, device, model_path = 'data/models/walt_vehicle.pth', threshold=0.85):
15
+ self.model = init_detector(cfg_path, model_path, device=device)
16
+ self.all_preds = []
17
+ self.all_scores = []
18
+ self.index = []
19
+ self.score_thr = threshold
20
+ self.result = []
21
+ self.record_dict = {'model': cfg_path,'results': []}
22
+ self.detect_count = []
23
+
24
+
25
+ def run_on_image(self, image):
26
+ self.result = inference_detector(self.model, image)
27
+ image_labelled = self.model.show_result(image, self.result, score_thr=self.score_thr)
28
+ return image_labelled
29
+
30
+ def process_output(self, count):
31
+ result = self.result
32
+ infer_result = {'url': count,
33
+ 'boxes': [],
34
+ 'scores': [],
35
+ 'keypoints': [],
36
+ 'segmentation': [],
37
+ 'label_ids': [],
38
+ 'track': [],
39
+ 'labels': []}
40
+
41
+ if isinstance(result, tuple):
42
+ bbox_result, segm_result = result
43
+ #segm_result = encode_mask_results(segm_result)
44
+ if isinstance(segm_result, tuple):
45
+ segm_result = segm_result[0] # ms rcnn
46
+ bboxes = np.vstack(bbox_result)
47
+ labels = [np.full(bbox.shape[0], i, dtype=np.int32) for i, bbox in enumerate(bbox_result)]
48
+
49
+ labels = np.concatenate(labels)
50
+ segms = None
51
+ if segm_result is not None and len(labels) > 0: # non empty
52
+ segms = mmcv.concat_list(segm_result)
53
+ if isinstance(segms[0], torch.Tensor):
54
+ segms = torch.stack(segms, dim=0).detach().cpu().numpy()
55
+ else:
56
+ segms = np.stack(segms, axis=0)
57
+
58
+ for i, (bbox, label, segm) in enumerate(zip(bboxes, labels, segms)):
59
+ if bbox[-1].item() < 0.3: # fixed score floor for serialized results (independent of score_thr)
60
+ continue
61
+ box = [bbox[0].item(), bbox[1].item(), bbox[2].item(), bbox[3].item()]
62
+ polygons = Mask(segm).polygons()
63
+
64
+ infer_result['boxes'].append(box)
65
+ infer_result['segmentation'].append(polygons.segmentation)
66
+ infer_result['scores'].append(bbox[-1].item())
67
+ infer_result['labels'].append(self.model.CLASSES[label])
68
+ infer_result['label_ids'].append(label)
69
+ self.record_dict['results'].append(infer_result)
70
+ self.detect_count = labels
71
+
72
+ def write_json(self, filename):
73
+ with open(filename + '.json', 'w') as f:
74
+ json.dump(self.record_dict, f)
75
+
76
+
77
+ def main():
78
+ if not torch.cuda.is_available():
79
+ device = 'cpu'
80
+ else:
81
+ device = 'cuda:0'
82
+ detect_people = detections('configs/walt/walt_people.py', device, model_path='data/models/walt_people.pth')
83
+ detect = detections('configs/walt/walt_vehicle.py', device, model_path='data/models/walt_vehicle.pth')
84
+ filenames = sorted(glob.glob('demo/images/*'))
85
+ count = 0
86
+ for filename in filenames:
87
+ img=cv2.imread(filename)
88
+ try:
89
+ img = detect_people.run_on_image(img)
90
+ img = detect.run_on_image(img)
91
+ except Exception: # skip frames that fail in either detector
92
+ continue
93
+ count=count+1
94
+
95
+ try:
96
+ # 'os' is already imported at the top of this file
97
+ os.makedirs(os.path.dirname(filename.replace('demo', 'demo/results/')), exist_ok=True)
98
+ # results mirror the demo/ layout under demo/results/
99
+ except OSError:
100
+ print('could not create output directory')
101
+ cv2.imwrite(filename.replace('demo', 'demo/results/'), img)
102
+ if count == 30000:
103
+ break
104
+ try:
105
+ detect.process_output(count)
106
+ except Exception: # tolerate frames with no processable detections
107
+ continue
108
+ '''
109
+
110
+ np.savez('FC', a= detect.record_dict)
111
+ with open('check.json', 'w') as f:
112
+ json.dump(detect.record_dict, f)
113
+ detect.write_json('seq3')
114
+ asas
115
+ detect.process_output(0)
116
+ '''
117
+ if __name__ == "__main__":
118
+ main()
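
The detections wrapper above can also be driven directly; a minimal sketch, assuming the repo's config and weights are in place and that demo/images/sample.jpg is a stand-in for a real frame:

    import cv2
    import torch

    device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    det = detections('configs/walt/walt_vehicle.py', device,
                     model_path='data/models/walt_vehicle.pth')

    img = cv2.imread('demo/images/sample.jpg')  # hypothetical input frame
    vis = det.run_on_image(img)                 # overlays boxes/masks above score_thr
    det.process_output(0)                       # accumulate detections for frame id 0
    det.write_json('detections')                # writes detections.json
    cv2.imwrite('sample_out.jpg', vis)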
mmcv_custom/__init__.py ADDED
@@ -0,0 +1,5 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ from .checkpoint import load_checkpoint
4
+
5
+ __all__ = ['load_checkpoint']
mmcv_custom/checkpoint.py ADDED
@@ -0,0 +1,500 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import io
3
+ import os
4
+ import os.path as osp
5
+ import pkgutil
6
+ import time
7
+ import warnings
8
+ from collections import OrderedDict
9
+ from importlib import import_module
10
+ from tempfile import TemporaryDirectory
11
+
12
+ import torch
13
+ import torchvision
14
+ from torch.optim import Optimizer
15
+ from torch.utils import model_zoo
16
+ from torch.nn import functional as F
17
+
18
+ import mmcv
19
+ from mmcv.fileio import FileClient
20
+ from mmcv.fileio import load as load_file
21
+ from mmcv.parallel import is_module_wrapper
22
+ from mmcv.utils import mkdir_or_exist
23
+ from mmcv.runner import get_dist_info
24
+
25
+ ENV_MMCV_HOME = 'MMCV_HOME'
26
+ ENV_XDG_CACHE_HOME = 'XDG_CACHE_HOME'
27
+ DEFAULT_CACHE_DIR = '~/.cache'
28
+
29
+
30
+ def _get_mmcv_home():
31
+ mmcv_home = os.path.expanduser(
32
+ os.getenv(
33
+ ENV_MMCV_HOME,
34
+ os.path.join(
35
+ os.getenv(ENV_XDG_CACHE_HOME, DEFAULT_CACHE_DIR), 'mmcv')))
36
+
37
+ mkdir_or_exist(mmcv_home)
38
+ return mmcv_home
39
+
40
+
41
+ def load_state_dict(module, state_dict, strict=False, logger=None):
42
+ """Load state_dict to a module.
43
+
44
+ This method is modified from :meth:`torch.nn.Module.load_state_dict`.
45
+ Default value for ``strict`` is set to ``False`` and the message for
46
+ param mismatch will be shown even if strict is False.
47
+
48
+ Args:
49
+ module (Module): Module that receives the state_dict.
50
+ state_dict (OrderedDict): Weights.
51
+ strict (bool): whether to strictly enforce that the keys
52
+ in :attr:`state_dict` match the keys returned by this module's
53
+ :meth:`~torch.nn.Module.state_dict` function. Default: ``False``.
54
+ logger (:obj:`logging.Logger`, optional): Logger to log the error
55
+ message. If not specified, print function will be used.
56
+ """
57
+ unexpected_keys = []
58
+ all_missing_keys = []
59
+ err_msg = []
60
+
61
+ metadata = getattr(state_dict, '_metadata', None)
62
+ state_dict = state_dict.copy()
63
+ if metadata is not None:
64
+ state_dict._metadata = metadata
65
+
66
+ # use _load_from_state_dict to enable checkpoint version control
67
+ def load(module, prefix=''):
68
+ # recursively check parallel module in case that the model has a
69
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
70
+ if is_module_wrapper(module):
71
+ module = module.module
72
+ local_metadata = {} if metadata is None else metadata.get(
73
+ prefix[:-1], {})
74
+ module._load_from_state_dict(state_dict, prefix, local_metadata, True,
75
+ all_missing_keys, unexpected_keys,
76
+ err_msg)
77
+ for name, child in module._modules.items():
78
+ if child is not None:
79
+ load(child, prefix + name + '.')
80
+
81
+ load(module)
82
+ load = None # break load->load reference cycle
83
+
84
+ # ignore "num_batches_tracked" of BN layers
85
+ missing_keys = [
86
+ key for key in all_missing_keys if 'num_batches_tracked' not in key
87
+ ]
88
+
89
+ if unexpected_keys:
90
+ err_msg.append('unexpected key in source '
91
+ f'state_dict: {", ".join(unexpected_keys)}\n')
92
+ if missing_keys:
93
+ err_msg.append(
94
+ f'missing keys in source state_dict: {", ".join(missing_keys)}\n')
95
+
96
+ rank, _ = get_dist_info()
97
+ if len(err_msg) > 0 and rank == 0:
98
+ err_msg.insert(
99
+ 0, 'The model and loaded state dict do not match exactly\n')
100
+ err_msg = '\n'.join(err_msg)
101
+ if strict:
102
+ raise RuntimeError(err_msg)
103
+ elif logger is not None:
104
+ logger.warning(err_msg)
105
+ else:
106
+ print(err_msg)
107
+
108
+
109
+ def load_url_dist(url, model_dir=None):
110
+ """In distributed setting, this function only download checkpoint at local
111
+ rank 0."""
112
+ rank, world_size = get_dist_info()
113
+ rank = int(os.environ.get('LOCAL_RANK', rank))
114
+ if rank == 0:
115
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
116
+ if world_size > 1:
117
+ torch.distributed.barrier()
118
+ if rank > 0:
119
+ checkpoint = model_zoo.load_url(url, model_dir=model_dir)
120
+ return checkpoint
121
+
122
+
123
+ def load_pavimodel_dist(model_path, map_location=None):
124
+ """In distributed setting, this function only download checkpoint at local
125
+ rank 0."""
126
+ try:
127
+ from pavi import modelcloud
128
+ except ImportError:
129
+ raise ImportError(
130
+ 'Please install pavi to load checkpoint from modelcloud.')
131
+ rank, world_size = get_dist_info()
132
+ rank = int(os.environ.get('LOCAL_RANK', rank))
133
+ if rank == 0:
134
+ model = modelcloud.get(model_path)
135
+ with TemporaryDirectory() as tmp_dir:
136
+ downloaded_file = osp.join(tmp_dir, model.name)
137
+ model.download(downloaded_file)
138
+ checkpoint = torch.load(downloaded_file, map_location=map_location)
139
+ if world_size > 1:
140
+ torch.distributed.barrier()
141
+ if rank > 0:
142
+ model = modelcloud.get(model_path)
143
+ with TemporaryDirectory() as tmp_dir:
144
+ downloaded_file = osp.join(tmp_dir, model.name)
145
+ model.download(downloaded_file)
146
+ checkpoint = torch.load(
147
+ downloaded_file, map_location=map_location)
148
+ return checkpoint
149
+
150
+
151
+ def load_fileclient_dist(filename, backend, map_location):
152
+ """In distributed setting, this function only download checkpoint at local
153
+ rank 0."""
154
+ rank, world_size = get_dist_info()
155
+ rank = int(os.environ.get('LOCAL_RANK', rank))
156
+ allowed_backends = ['ceph']
157
+ if backend not in allowed_backends:
158
+ raise ValueError(f'Load from Backend {backend} is not supported.')
159
+ if rank == 0:
160
+ fileclient = FileClient(backend=backend)
161
+ buffer = io.BytesIO(fileclient.get(filename))
162
+ checkpoint = torch.load(buffer, map_location=map_location)
163
+ if world_size > 1:
164
+ torch.distributed.barrier()
165
+ if rank > 0:
166
+ fileclient = FileClient(backend=backend)
167
+ buffer = io.BytesIO(fileclient.get(filename))
168
+ checkpoint = torch.load(buffer, map_location=map_location)
169
+ return checkpoint
170
+
171
+
172
+ def get_torchvision_models():
173
+ model_urls = dict()
174
+ for _, name, ispkg in pkgutil.walk_packages(torchvision.models.__path__):
175
+ if ispkg:
176
+ continue
177
+ _zoo = import_module(f'torchvision.models.{name}')
178
+ if hasattr(_zoo, 'model_urls'):
179
+ _urls = getattr(_zoo, 'model_urls')
180
+ model_urls.update(_urls)
181
+ return model_urls
182
+
183
+
184
+ def get_external_models():
185
+ mmcv_home = _get_mmcv_home()
186
+ default_json_path = osp.join(mmcv.__path__[0], 'model_zoo/open_mmlab.json')
187
+ default_urls = load_file(default_json_path)
188
+ assert isinstance(default_urls, dict)
189
+ external_json_path = osp.join(mmcv_home, 'open_mmlab.json')
190
+ if osp.exists(external_json_path):
191
+ external_urls = load_file(external_json_path)
192
+ assert isinstance(external_urls, dict)
193
+ default_urls.update(external_urls)
194
+
195
+ return default_urls
196
+
197
+
198
+ def get_mmcls_models():
199
+ mmcls_json_path = osp.join(mmcv.__path__[0], 'model_zoo/mmcls.json')
200
+ mmcls_urls = load_file(mmcls_json_path)
201
+
202
+ return mmcls_urls
203
+
204
+
205
+ def get_deprecated_model_names():
206
+ deprecate_json_path = osp.join(mmcv.__path__[0],
207
+ 'model_zoo/deprecated.json')
208
+ deprecate_urls = load_file(deprecate_json_path)
209
+ assert isinstance(deprecate_urls, dict)
210
+
211
+ return deprecate_urls
212
+
213
+
214
+ def _process_mmcls_checkpoint(checkpoint):
215
+ state_dict = checkpoint['state_dict']
216
+ new_state_dict = OrderedDict()
217
+ for k, v in state_dict.items():
218
+ if k.startswith('backbone.'):
219
+ new_state_dict[k[9:]] = v
220
+ new_checkpoint = dict(state_dict=new_state_dict)
221
+
222
+ return new_checkpoint
223
+
224
+
225
+ def _load_checkpoint(filename, map_location=None):
226
+ """Load checkpoint from somewhere (modelzoo, file, url).
227
+
228
+ Args:
229
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
230
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
231
+ details.
232
+ map_location (str | None): Same as :func:`torch.load`. Default: None.
233
+
234
+ Returns:
235
+ dict | OrderedDict: The loaded checkpoint. It can be either an
236
+ OrderedDict storing model weights or a dict containing other
237
+ information, which depends on the checkpoint.
238
+ """
239
+ if filename.startswith('modelzoo://'):
240
+ warnings.warn('The URL scheme of "modelzoo://" is deprecated, please '
241
+ 'use "torchvision://" instead')
242
+ model_urls = get_torchvision_models()
243
+ model_name = filename[11:]
244
+ checkpoint = load_url_dist(model_urls[model_name])
245
+ elif filename.startswith('torchvision://'):
246
+ model_urls = get_torchvision_models()
247
+ model_name = filename[14:]
248
+ checkpoint = load_url_dist(model_urls[model_name])
249
+ elif filename.startswith('open-mmlab://'):
250
+ model_urls = get_external_models()
251
+ model_name = filename[13:]
252
+ deprecated_urls = get_deprecated_model_names()
253
+ if model_name in deprecated_urls:
254
+ warnings.warn(f'open-mmlab://{model_name} is deprecated in favor '
255
+ f'of open-mmlab://{deprecated_urls[model_name]}')
256
+ model_name = deprecated_urls[model_name]
257
+ model_url = model_urls[model_name]
258
+ # check if is url
259
+ if model_url.startswith(('http://', 'https://')):
260
+ checkpoint = load_url_dist(model_url)
261
+ else:
262
+ filename = osp.join(_get_mmcv_home(), model_url)
263
+ if not osp.isfile(filename):
264
+ raise IOError(f'{filename} is not a checkpoint file')
265
+ checkpoint = torch.load(filename, map_location=map_location)
266
+ elif filename.startswith('mmcls://'):
267
+ model_urls = get_mmcls_models()
268
+ model_name = filename[8:]
269
+ checkpoint = load_url_dist(model_urls[model_name])
270
+ checkpoint = _process_mmcls_checkpoint(checkpoint)
271
+ elif filename.startswith(('http://', 'https://')):
272
+ checkpoint = load_url_dist(filename)
273
+ elif filename.startswith('pavi://'):
274
+ model_path = filename[7:]
275
+ checkpoint = load_pavimodel_dist(model_path, map_location=map_location)
276
+ elif filename.startswith('s3://'):
277
+ checkpoint = load_fileclient_dist(
278
+ filename, backend='ceph', map_location=map_location)
279
+ else:
280
+ if not osp.isfile(filename):
281
+ raise IOError(f'{filename} is not a checkpoint file')
282
+ checkpoint = torch.load(filename, map_location=map_location)
283
+ return checkpoint
284
+
285
+
286
+ def load_checkpoint(model,
287
+ filename,
288
+ map_location='cpu',
289
+ strict=False,
290
+ logger=None):
291
+ """Load checkpoint from a file or URI.
292
+
293
+ Args:
294
+ model (Module): Module to load checkpoint.
295
+ filename (str): Accept local filepath, URL, ``torchvision://xxx``,
296
+ ``open-mmlab://xxx``. Please refer to ``docs/model_zoo.md`` for
297
+ details.
298
+ map_location (str): Same as :func:`torch.load`.
299
+ strict (bool): Whether to allow different params for the model and
300
+ checkpoint.
301
+ logger (:mod:`logging.Logger` or None): The logger for error message.
302
+
303
+ Returns:
304
+ dict or OrderedDict: The loaded checkpoint.
305
+ """
306
+ checkpoint = _load_checkpoint(filename, map_location)
307
+ # OrderedDict is a subclass of dict
308
+ if not isinstance(checkpoint, dict):
309
+ raise RuntimeError(
310
+ f'No state_dict found in checkpoint file {filename}')
311
+ # get state_dict from checkpoint
312
+ if 'state_dict' in checkpoint:
313
+ state_dict = checkpoint['state_dict']
314
+ elif 'model' in checkpoint:
315
+ state_dict = checkpoint['model']
316
+ else:
317
+ state_dict = checkpoint
318
+ # strip prefix of state_dict
319
+ if list(state_dict.keys())[0].startswith('module.'):
320
+ state_dict = {k[7:]: v for k, v in state_dict.items()}
321
+
322
+ # for MoBY, load model of online branch
323
+ if sorted(list(state_dict.keys()))[0].startswith('encoder'):
324
+ state_dict = {k.replace('encoder.', ''): v for k, v in state_dict.items() if k.startswith('encoder.')}
325
+
326
+ # reshape absolute position embedding
327
+ if state_dict.get('absolute_pos_embed') is not None:
328
+ absolute_pos_embed = state_dict['absolute_pos_embed']
329
+ N1, L, C1 = absolute_pos_embed.size()
330
+ N2, C2, H, W = model.absolute_pos_embed.size()
331
+ if N1 != N2 or C1 != C2 or L != H*W:
332
+ warnings.warn("Error in loading absolute_pos_embed, pass") # logger may be None here
333
+ else:
334
+ state_dict['absolute_pos_embed'] = absolute_pos_embed.view(N2, H, W, C2).permute(0, 3, 1, 2)
335
+
336
+ # interpolate position bias table if needed
337
+ relative_position_bias_table_keys = [k for k in state_dict.keys() if "relative_position_bias_table" in k]
338
+ for table_key in relative_position_bias_table_keys:
339
+ table_pretrained = state_dict[table_key]
340
+ table_current = model.state_dict()[table_key]
341
+ L1, nH1 = table_pretrained.size()
342
+ L2, nH2 = table_current.size()
343
+ if nH1 != nH2:
344
+ warnings.warn(f"Error in loading {table_key}, pass") # logger may be None here
345
+ else:
346
+ if L1 != L2:
347
+ S1 = int(L1 ** 0.5)
348
+ S2 = int(L2 ** 0.5)
349
+ table_pretrained_resized = F.interpolate(
350
+ table_pretrained.permute(1, 0).view(1, nH1, S1, S1),
351
+ size=(S2, S2), mode='bicubic')
352
+ state_dict[table_key] = table_pretrained_resized.view(nH2, L2).permute(1, 0)
353
+
354
+ # load state_dict
355
+ load_state_dict(model, state_dict, strict, logger)
356
+ return checkpoint
357
+
358
+
359
+ def weights_to_cpu(state_dict):
360
+ """Copy a model state_dict to cpu.
361
+
362
+ Args:
363
+ state_dict (OrderedDict): Model weights on GPU.
364
+
365
+ Returns:
366
+ OrderedDict: Model weights on CPU.
367
+ """
368
+ state_dict_cpu = OrderedDict()
369
+ for key, val in state_dict.items():
370
+ state_dict_cpu[key] = val.cpu()
371
+ return state_dict_cpu
372
+
373
+
374
+ def _save_to_state_dict(module, destination, prefix, keep_vars):
375
+ """Saves module state to `destination` dictionary.
376
+
377
+ This method is modified from :meth:`torch.nn.Module._save_to_state_dict`.
378
+
379
+ Args:
380
+ module (nn.Module): The module to generate state_dict.
381
+ destination (dict): A dict where state will be stored.
382
+ prefix (str): The prefix for parameters and buffers used in this
383
+ module.
384
+ """
385
+ for name, param in module._parameters.items():
386
+ if param is not None:
387
+ destination[prefix + name] = param if keep_vars else param.detach()
388
+ for name, buf in module._buffers.items():
389
+ # remove check of _non_persistent_buffers_set to allow nn.BatchNorm2d
390
+ if buf is not None:
391
+ destination[prefix + name] = buf if keep_vars else buf.detach()
392
+
393
+
394
+ def get_state_dict(module, destination=None, prefix='', keep_vars=False):
395
+ """Returns a dictionary containing a whole state of the module.
396
+
397
+ Both parameters and persistent buffers (e.g. running averages) are
398
+ included. Keys are corresponding parameter and buffer names.
399
+
400
+ This method is modified from :meth:`torch.nn.Module.state_dict` to
401
+ recursively check parallel module in case that the model has a complicated
402
+ structure, e.g., nn.Module(nn.Module(DDP)).
403
+
404
+ Args:
405
+ module (nn.Module): The module to generate state_dict.
406
+ destination (OrderedDict): Returned dict for the state of the
407
+ module.
408
+ prefix (str): Prefix of the key.
409
+ keep_vars (bool): Whether to keep the variable property of the
410
+ parameters. Default: False.
411
+
412
+ Returns:
413
+ dict: A dictionary containing a whole state of the module.
414
+ """
415
+ # recursively check parallel module in case that the model has a
416
+ # complicated structure, e.g., nn.Module(nn.Module(DDP))
417
+ if is_module_wrapper(module):
418
+ module = module.module
419
+
420
+ # below is the same as torch.nn.Module.state_dict()
421
+ if destination is None:
422
+ destination = OrderedDict()
423
+ destination._metadata = OrderedDict()
424
+ destination._metadata[prefix[:-1]] = local_metadata = dict(
425
+ version=module._version)
426
+ _save_to_state_dict(module, destination, prefix, keep_vars)
427
+ for name, child in module._modules.items():
428
+ if child is not None:
429
+ get_state_dict(
430
+ child, destination, prefix + name + '.', keep_vars=keep_vars)
431
+ for hook in module._state_dict_hooks.values():
432
+ hook_result = hook(module, destination, prefix, local_metadata)
433
+ if hook_result is not None:
434
+ destination = hook_result
435
+ return destination
436
+
437
+
438
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
439
+ """Save checkpoint to file.
440
+
441
+ The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
442
+ ``optimizer``. By default ``meta`` will contain version and time info.
443
+
444
+ Args:
445
+ model (Module): Module whose params are to be saved.
446
+ filename (str): Checkpoint filename.
447
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
448
+ meta (dict, optional): Metadata to be saved in checkpoint.
449
+ """
450
+ if meta is None:
451
+ meta = {}
452
+ elif not isinstance(meta, dict):
453
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
454
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
455
+
456
+ if is_module_wrapper(model):
457
+ model = model.module
458
+
459
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
460
+ # save class name to the meta
461
+ meta.update(CLASSES=model.CLASSES)
462
+
463
+ checkpoint = {
464
+ 'meta': meta,
465
+ 'state_dict': weights_to_cpu(get_state_dict(model))
466
+ }
467
+ # save optimizer state dict in the checkpoint
468
+ if isinstance(optimizer, Optimizer):
469
+ checkpoint['optimizer'] = optimizer.state_dict()
470
+ elif isinstance(optimizer, dict):
471
+ checkpoint['optimizer'] = {}
472
+ for name, optim in optimizer.items():
473
+ checkpoint['optimizer'][name] = optim.state_dict()
474
+
475
+ if filename.startswith('pavi://'):
476
+ try:
477
+ from pavi import modelcloud
478
+ from pavi.exception import NodeNotFoundError
479
+ except ImportError:
480
+ raise ImportError(
481
+ 'Please install pavi to load checkpoint from modelcloud.')
482
+ model_path = filename[7:]
483
+ root = modelcloud.Folder()
484
+ model_dir, model_name = osp.split(model_path)
485
+ try:
486
+ model = modelcloud.get(model_dir)
487
+ except NodeNotFoundError:
488
+ model = root.create_training_model(model_dir)
489
+ with TemporaryDirectory() as tmp_dir:
490
+ checkpoint_file = osp.join(tmp_dir, model_name)
491
+ with open(checkpoint_file, 'wb') as f:
492
+ torch.save(checkpoint, f)
493
+ f.flush()
494
+ model.create_file(checkpoint_file, name=model_name)
495
+ else:
496
+ mmcv.mkdir_or_exist(osp.dirname(filename))
497
+ # immediately flush buffer
498
+ with open(filename, 'wb') as f:
499
+ torch.save(checkpoint, f)
500
+ f.flush()
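
In practice this loader is what lets Swin weights pretrained at one resolution initialize a model at another: the absolute position embedding is reshaped and each relative position bias table is bicubically resized when the stored and expected lengths differ. A hedged usage sketch (the checkpoint filename is an assumption):

    from mmcv_custom import load_checkpoint

    # 'model' is an already-built Swin-based detector; the .pth name is illustrative.
    checkpoint = load_checkpoint(model, 'swin_tiny_patch4_window7_224.pth',
                                 map_location='cpu', strict=False)
    # With strict=False, key mismatches are reported rather than raised.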
mmcv_custom/runner/__init__.py ADDED
@@ -0,0 +1,8 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ from .checkpoint import save_checkpoint
3
+ from .epoch_based_runner import EpochBasedRunnerAmp
4
+
5
+
6
+ __all__ = [
7
+ 'EpochBasedRunnerAmp', 'save_checkpoint'
8
+ ]
mmcv_custom/runner/checkpoint.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import os.path as osp
3
+ import time
4
+ from tempfile import TemporaryDirectory
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+ import mmcv
10
+ from mmcv.parallel import is_module_wrapper
11
+ from mmcv.runner.checkpoint import weights_to_cpu, get_state_dict
12
+
13
+ try:
14
+ import apex
15
+ except ImportError:
16
+ print('apex is not installed')
17
+
18
+
19
+ def save_checkpoint(model, filename, optimizer=None, meta=None):
20
+ """Save checkpoint to file.
21
+
22
+ The checkpoint will have 4 fields: ``meta``, ``state_dict``,
23
+ ``optimizer`` and ``amp``. By default ``meta`` will contain version
24
+ and time info.
25
+
26
+ Args:
27
+ model (Module): Module whose params are to be saved.
28
+ filename (str): Checkpoint filename.
29
+ optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
30
+ meta (dict, optional): Metadata to be saved in checkpoint.
31
+ """
32
+ if meta is None:
33
+ meta = {}
34
+ elif not isinstance(meta, dict):
35
+ raise TypeError(f'meta must be a dict or None, but got {type(meta)}')
36
+ meta.update(mmcv_version=mmcv.__version__, time=time.asctime())
37
+
38
+ if is_module_wrapper(model):
39
+ model = model.module
40
+
41
+ if hasattr(model, 'CLASSES') and model.CLASSES is not None:
42
+ # save class name to the meta
43
+ meta.update(CLASSES=model.CLASSES)
44
+
45
+ checkpoint = {
46
+ 'meta': meta,
47
+ 'state_dict': weights_to_cpu(get_state_dict(model))
48
+ }
49
+ # save optimizer state dict in the checkpoint
50
+ if isinstance(optimizer, Optimizer):
51
+ checkpoint['optimizer'] = optimizer.state_dict()
52
+ elif isinstance(optimizer, dict):
53
+ checkpoint['optimizer'] = {}
54
+ for name, optim in optimizer.items():
55
+ checkpoint['optimizer'][name] = optim.state_dict()
56
+
57
+ # save amp state dict in the checkpoint
58
+ checkpoint['amp'] = apex.amp.state_dict()
59
+
60
+ if filename.startswith('pavi://'):
61
+ try:
62
+ from pavi import modelcloud
63
+ from pavi.exception import NodeNotFoundError
64
+ except ImportError:
65
+ raise ImportError(
66
+ 'Please install pavi to load checkpoint from modelcloud.')
67
+ model_path = filename[7:]
68
+ root = modelcloud.Folder()
69
+ model_dir, model_name = osp.split(model_path)
70
+ try:
71
+ model = modelcloud.get(model_dir)
72
+ except NodeNotFoundError:
73
+ model = root.create_training_model(model_dir)
74
+ with TemporaryDirectory() as tmp_dir:
75
+ checkpoint_file = osp.join(tmp_dir, model_name)
76
+ with open(checkpoint_file, 'wb') as f:
77
+ torch.save(checkpoint, f)
78
+ f.flush()
79
+ model.create_file(checkpoint_file, name=model_name)
80
+ else:
81
+ mmcv.mkdir_or_exist(osp.dirname(filename))
82
+ # immediately flush buffer
83
+ with open(filename, 'wb') as f:
84
+ torch.save(checkpoint, f)
85
+ f.flush()
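
The only difference from stock mmcv here is the extra 'amp' field, which lets a resumed run restore apex's loss-scale state. A small sketch, assuming model and optimizer have already passed through apex.amp.initialize:

    from mmcv_custom.runner import save_checkpoint

    # 'model' and 'optimizer' are assumed to be apex.amp-initialized already.
    save_checkpoint(model, 'work_dirs/epoch_12.pth', optimizer=optimizer,
                    meta=dict(epoch=12, iter=12000))
    # The saved dict now holds 'meta', 'state_dict', 'optimizer' and 'amp'.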
mmcv_custom/runner/epoch_based_runner.py ADDED
@@ -0,0 +1,104 @@
1
+ # Copyright (c) Open-MMLab. All rights reserved.
2
+ import os.path as osp
3
+ import platform
4
+ import shutil
5
+
6
+ import torch
7
+ from torch.optim import Optimizer
8
+
9
+ import mmcv
10
+ from mmcv.runner import RUNNERS, EpochBasedRunner
11
+ from .checkpoint import save_checkpoint
12
+
13
+ try:
14
+ import apex
15
+ except ImportError:
16
+ print('apex is not installed')
17
+
18
+
19
+ @RUNNERS.register_module()
20
+ class EpochBasedRunnerAmp(EpochBasedRunner):
21
+ """Epoch-based Runner with AMP support.
22
+
23
+ This runner train models epoch by epoch.
24
+ """
25
+
26
+ def save_checkpoint(self,
27
+ out_dir,
28
+ filename_tmpl='epoch_{}.pth',
29
+ save_optimizer=True,
30
+ meta=None,
31
+ create_symlink=True):
32
+ """Save the checkpoint.
33
+
34
+ Args:
35
+ out_dir (str): The directory that checkpoints are saved.
36
+ filename_tmpl (str, optional): The checkpoint filename template,
37
+ which contains a placeholder for the epoch number.
38
+ Defaults to 'epoch_{}.pth'.
39
+ save_optimizer (bool, optional): Whether to save the optimizer to
40
+ the checkpoint. Defaults to True.
41
+ meta (dict, optional): The meta information to be saved in the
42
+ checkpoint. Defaults to None.
43
+ create_symlink (bool, optional): Whether to create a symlink
44
+ "latest.pth" to point to the latest checkpoint.
45
+ Defaults to True.
46
+ """
47
+ if meta is None:
48
+ meta = dict(epoch=self.epoch + 1, iter=self.iter)
49
+ elif isinstance(meta, dict):
50
+ meta.update(epoch=self.epoch + 1, iter=self.iter)
51
+ else:
52
+ raise TypeError(
53
+ f'meta should be a dict or None, but got {type(meta)}')
54
+ if self.meta is not None:
55
+ meta.update(self.meta)
56
+
57
+ filename = filename_tmpl.format(self.epoch + 1)
58
+ filepath = osp.join(out_dir, filename)
59
+ optimizer = self.optimizer if save_optimizer else None
60
+ save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
61
+ # in some environments, `os.symlink` is not supported, you may need to
62
+ # set `create_symlink` to False
63
+ if create_symlink:
64
+ dst_file = osp.join(out_dir, 'latest.pth')
65
+ if platform.system() != 'Windows':
66
+ mmcv.symlink(filename, dst_file)
67
+ else:
68
+ shutil.copy(filepath, dst_file)
69
+
70
+ def resume(self,
71
+ checkpoint,
72
+ resume_optimizer=True,
73
+ map_location='default'):
74
+ if map_location == 'default':
75
+ if torch.cuda.is_available():
76
+ device_id = torch.cuda.current_device()
77
+ checkpoint = self.load_checkpoint(
78
+ checkpoint,
79
+ map_location=lambda storage, loc: storage.cuda(device_id))
80
+ else:
81
+ checkpoint = self.load_checkpoint(checkpoint)
82
+ else:
83
+ checkpoint = self.load_checkpoint(
84
+ checkpoint, map_location=map_location)
85
+
86
+ self._epoch = checkpoint['meta']['epoch']
87
+ self._iter = checkpoint['meta']['iter']
88
+ if 'optimizer' in checkpoint and resume_optimizer:
89
+ if isinstance(self.optimizer, Optimizer):
90
+ self.optimizer.load_state_dict(checkpoint['optimizer'])
91
+ elif isinstance(self.optimizer, dict):
92
+ for k in self.optimizer.keys():
93
+ self.optimizer[k].load_state_dict(
94
+ checkpoint['optimizer'][k])
95
+ else:
96
+ raise TypeError(
97
+ 'Optimizer should be dict or torch.optim.Optimizer '
98
+ f'but got {type(self.optimizer)}')
99
+
100
+ if 'amp' in checkpoint:
101
+ apex.amp.load_state_dict(checkpoint['amp'])
102
+ self.logger.info('load amp state dict')
103
+
104
+ self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
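
Because EpochBasedRunnerAmp registers itself with mmcv's RUNNERS registry, it is selected purely from the training config; the same config also has to request the fp16 optimizer hook that mmdet/apis/train.py below checks for. An illustrative fragment (values are placeholders):

    # excerpt of a training config
    runner = dict(type='EpochBasedRunnerAmp', max_epochs=12)
    optimizer_config = dict(type='DistOptimizerHook', use_fp16=True)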
mmdet/__init__.py ADDED
@@ -0,0 +1,28 @@
1
+ import mmcv
2
+
3
+ from .version import __version__, short_version
4
+
5
+
6
+ def digit_version(version_str):
7
+ digit_version = []
8
+ for x in version_str.split('.'):
9
+ if x.isdigit():
10
+ digit_version.append(int(x))
11
+ elif x.find('rc') != -1:
12
+ patch_version = x.split('rc')
13
+ digit_version.append(int(patch_version[0]) - 1)
14
+ digit_version.append(int(patch_version[1]))
15
+ return digit_version
16
+
17
+
18
+ mmcv_minimum_version = '1.2.4'
19
+ mmcv_maximum_version = '1.4.0'
20
+ mmcv_version = digit_version(mmcv.__version__)
21
+
22
+
23
+ assert (mmcv_version >= digit_version(mmcv_minimum_version)
24
+ and mmcv_version <= digit_version(mmcv_maximum_version)), \
25
+ f'MMCV=={mmcv.__version__} is used but incompatible. ' \
26
+ f'Please install mmcv>={mmcv_minimum_version}, <={mmcv_maximum_version}.'
27
+
28
+ __all__ = ['__version__', 'short_version']
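
The version check treats a release candidate as sorting just below its final release; two examples of digit_version worked by hand from the code above:

    digit_version('1.3.0')     # -> [1, 3, 0]
    digit_version('1.3.0rc1')  # -> [1, 3, -1, 1]  ('0rc1' yields 0 - 1, then 1)
    # so [1, 3, -1, 1] < [1, 3, 0], and 1.3.0rc1 compares below 1.3.0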
mmdet/apis/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ from .inference import (async_inference_detector, inference_detector,
2
+ init_detector, show_result_pyplot)
3
+ from .test import multi_gpu_test, single_gpu_test
4
+ from .train import get_root_logger, set_random_seed, train_detector
5
+
6
+ __all__ = [
7
+ 'get_root_logger', 'set_random_seed', 'train_detector', 'init_detector',
8
+ 'async_inference_detector', 'inference_detector', 'show_result_pyplot',
9
+ 'multi_gpu_test', 'single_gpu_test'
10
+ ]
mmdet/apis/inference.py ADDED
@@ -0,0 +1,217 @@
1
+ import warnings
2
+
3
+ import mmcv
4
+ import numpy as np
5
+ import torch
6
+ from mmcv.ops import RoIPool
7
+ from mmcv.parallel import collate, scatter
8
+ from mmcv.runner import load_checkpoint
9
+
10
+ from mmdet.core import get_classes
11
+ from mmdet.datasets import replace_ImageToTensor
12
+ from mmdet.datasets.pipelines import Compose
13
+ from mmdet.models import build_detector
14
+
15
+
16
+ def init_detector(config, checkpoint=None, device='cuda:0', cfg_options=None):
17
+ """Initialize a detector from config file.
18
+
19
+ Args:
20
+ config (str or :obj:`mmcv.Config`): Config file path or the config
21
+ object.
22
+ checkpoint (str, optional): Checkpoint path. If left as None, the model
23
+ will not load any weights.
24
+ cfg_options (dict): Options to override some settings in the used
25
+ config.
26
+
27
+ Returns:
28
+ nn.Module: The constructed detector.
29
+ """
30
+ if isinstance(config, str):
31
+ config = mmcv.Config.fromfile(config)
32
+ elif not isinstance(config, mmcv.Config):
33
+ raise TypeError('config must be a filename or Config object, '
34
+ f'but got {type(config)}')
35
+ if cfg_options is not None:
36
+ config.merge_from_dict(cfg_options)
37
+ config.model.pretrained = None
38
+ config.model.train_cfg = None
39
+ model = build_detector(config.model, test_cfg=config.get('test_cfg'))
40
+ if checkpoint is not None:
41
+ map_loc = 'cpu' if device == 'cpu' else None
42
+ checkpoint = load_checkpoint(model, checkpoint, map_location=map_loc)
43
+ if 'CLASSES' in checkpoint.get('meta', {}):
44
+ model.CLASSES = checkpoint['meta']['CLASSES']
45
+ else:
46
+ warnings.simplefilter('once')
47
+ warnings.warn('Class names are not saved in the checkpoint\'s '
48
+ 'meta data, use COCO classes by default.')
49
+ model.CLASSES = get_classes('coco')
50
+ model.cfg = config # save the config in the model for convenience
51
+ model.to(device)
52
+ model.eval()
53
+ return model
54
+
55
+
56
+ class LoadImage(object):
57
+ """Deprecated.
58
+
59
+ A simple pipeline to load image.
60
+ """
61
+
62
+ def __call__(self, results):
63
+ """Call function to load images into results.
64
+
65
+ Args:
66
+ results (dict): A result dict contains the file name
67
+ of the image to be read.
68
+ Returns:
69
+ dict: ``results`` will be returned containing loaded image.
70
+ """
71
+ warnings.simplefilter('once')
72
+ warnings.warn('`LoadImage` is deprecated and will be removed in '
73
+ 'future releases. You may use `LoadImageFromWebcam` '
74
+ 'from `mmdet.datasets.pipelines.` instead.')
75
+ if isinstance(results['img'], str):
76
+ results['filename'] = results['img']
77
+ results['ori_filename'] = results['img']
78
+ else:
79
+ results['filename'] = None
80
+ results['ori_filename'] = None
81
+ img = mmcv.imread(results['img'])
82
+ results['img'] = img
83
+ results['img_fields'] = ['img']
84
+ results['img_shape'] = img.shape
85
+ results['ori_shape'] = img.shape
86
+ return results
87
+
88
+
89
+ def inference_detector(model, imgs):
90
+ """Inference image(s) with the detector.
91
+
92
+ Args:
93
+ model (nn.Module): The loaded detector.
94
+ imgs (str/ndarray or list[str/ndarray] or tuple[str/ndarray]):
95
+ Either image files or loaded images.
96
+
97
+ Returns:
98
+ If imgs is a list or tuple, the same length list type results
99
+ will be returned, otherwise return the detection results directly.
100
+ """
101
+
102
+ if isinstance(imgs, (list, tuple)):
103
+ is_batch = True
104
+ else:
105
+ imgs = [imgs]
106
+ is_batch = False
107
+
108
+ cfg = model.cfg
109
+ device = next(model.parameters()).device # model device
110
+
111
+ if isinstance(imgs[0], np.ndarray):
112
+ cfg = cfg.copy()
113
+ # set loading pipeline type
114
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
115
+
116
+ cfg.data.test.pipeline = replace_ImageToTensor(cfg.data.test.pipeline)
117
+ test_pipeline = Compose(cfg.data.test.pipeline)
118
+
119
+ datas = []
120
+ for img in imgs:
121
+ # prepare data
122
+ if isinstance(img, np.ndarray):
123
+ # directly add img
124
+ data = dict(img=img)
125
+ else:
126
+ # add information into dict
127
+ data = dict(img_info=dict(filename=img), img_prefix=None)
128
+ # build the data pipeline
129
+ data = test_pipeline(data)
130
+ datas.append(data)
131
+
132
+ data = collate(datas, samples_per_gpu=len(imgs))
133
+ # just get the actual data from DataContainer
134
+ data['img_metas'] = [img_metas.data[0] for img_metas in data['img_metas']]
135
+ data['img'] = [img.data[0] for img in data['img']]
136
+ if next(model.parameters()).is_cuda:
137
+ # scatter to specified GPU
138
+ data = scatter(data, [device])[0]
139
+ else:
140
+ for m in model.modules():
141
+ assert not isinstance(
142
+ m, RoIPool
143
+ ), 'CPU inference with RoIPool is not supported currently.'
144
+
145
+ # forward the model
146
+ with torch.no_grad():
147
+ results = model(return_loss=False, rescale=True, **data)
148
+
149
+ if not is_batch:
150
+ return results[0]
151
+ else:
152
+ return results
153
+
154
+
155
+ async def async_inference_detector(model, img):
156
+ """Async inference image(s) with the detector.
157
+
158
+ Args:
159
+ model (nn.Module): The loaded detector.
160
+ img (str | ndarray): Either image files or loaded images.
161
+
162
+ Returns:
163
+ Awaitable detection results.
164
+ """
165
+ cfg = model.cfg
166
+ device = next(model.parameters()).device # model device
167
+ # prepare data
168
+ if isinstance(img, np.ndarray):
169
+ # directly add img
170
+ data = dict(img=img)
171
+ cfg = cfg.copy()
172
+ # set loading pipeline type
173
+ cfg.data.test.pipeline[0].type = 'LoadImageFromWebcam'
174
+ else:
175
+ # add information into dict
176
+ data = dict(img_info=dict(filename=img), img_prefix=None)
177
+ # build the data pipeline
178
+ test_pipeline = Compose(cfg.data.test.pipeline)
179
+ data = test_pipeline(data)
180
+ data = scatter(collate([data], samples_per_gpu=1), [device])[0]
181
+
182
+ # We don't restore `torch.is_grad_enabled()` value during concurrent
183
+ # inference since execution can overlap
184
+ torch.set_grad_enabled(False)
185
+ result = await model.aforward_test(rescale=True, **data)
186
+ return result
187
+
188
+
189
+ def show_result_pyplot(model,
190
+ img,
191
+ result,
192
+ score_thr=0.3,
193
+ title='result',
194
+ wait_time=0):
195
+ """Visualize the detection results on the image.
196
+
197
+ Args:
198
+ model (nn.Module): The loaded detector.
199
+ img (str or np.ndarray): Image filename or loaded image.
200
+ result (tuple[list] or list): The detection result, can be either
201
+ (bbox, segm) or just bbox.
202
+ score_thr (float): The threshold to visualize the bboxes and masks.
203
+ title (str): Title of the pyplot figure.
204
+ wait_time (float): Value of waitKey param.
205
+ Default: 0.
206
+ """
207
+ if hasattr(model, 'module'):
208
+ model = model.module
209
+ model.show_result(
210
+ img,
211
+ result,
212
+ score_thr=score_thr,
213
+ show=True,
214
+ wait_time=wait_time,
215
+ win_name=title,
216
+ bbox_color=(72, 101, 241),
217
+ text_color=(72, 101, 241))
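
Taken together, these helpers cover the whole single-image flow; a minimal sketch assuming a local config and checkpoint (both paths are placeholders):

    from mmdet.apis import init_detector, inference_detector, show_result_pyplot

    model = init_detector('configs/walt/walt_vehicle.py',
                          'data/models/walt_vehicle.pth',  # placeholder path
                          device='cuda:0')
    result = inference_detector(model, 'demo/images/sample.jpg')
    show_result_pyplot(model, 'demo/images/sample.jpg', result, score_thr=0.3)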
mmdet/apis/test.py ADDED
@@ -0,0 +1,189 @@
1
+ import os.path as osp
2
+ import pickle
3
+ import shutil
4
+ import tempfile
5
+ import time
6
+
7
+ import mmcv
8
+ import torch
9
+ import torch.distributed as dist
10
+ from mmcv.image import tensor2imgs
11
+ from mmcv.runner import get_dist_info
12
+
13
+ from mmdet.core import encode_mask_results
14
+
15
+
16
+ def single_gpu_test(model,
17
+ data_loader,
18
+ show=False,
19
+ out_dir=None,
20
+ show_score_thr=0.3):
21
+ model.eval()
22
+ results = []
23
+ dataset = data_loader.dataset
24
+ prog_bar = mmcv.ProgressBar(len(dataset))
25
+ for i, data in enumerate(data_loader):
26
+ with torch.no_grad():
27
+ result = model(return_loss=False, rescale=True, **data)
28
+
29
+ batch_size = len(result)
30
+ if show or out_dir:
31
+ if batch_size == 1 and isinstance(data['img'][0], torch.Tensor):
32
+ img_tensor = data['img'][0]
33
+ else:
34
+ img_tensor = data['img'][0].data[0]
35
+ img_metas = data['img_metas'][0].data[0]
36
+ imgs = tensor2imgs(img_tensor, **img_metas[0]['img_norm_cfg'])
37
+ assert len(imgs) == len(img_metas)
38
+
39
+ for i, (img, img_meta) in enumerate(zip(imgs, img_metas)):
40
+ h, w, _ = img_meta['img_shape']
41
+ img_show = img[:h, :w, :]
42
+
43
+ ori_h, ori_w = img_meta['ori_shape'][:-1]
44
+ img_show = mmcv.imresize(img_show, (ori_w, ori_h))
45
+
46
+ if out_dir:
47
+ out_file = osp.join(out_dir, img_meta['ori_filename'])
48
+ else:
49
+ out_file = None
50
+ model.module.show_result(
51
+ img_show,
52
+ result[i],
53
+ show=show,
54
+ out_file=out_file,
55
+ score_thr=show_score_thr)
56
+
57
+ # encode mask results
58
+ if isinstance(result[0], tuple):
59
+ result = [(bbox_results, encode_mask_results(mask_results))
60
+ for bbox_results, mask_results in result]
61
+ results.extend(result)
62
+
63
+ for _ in range(batch_size):
64
+ prog_bar.update()
65
+ return results
66
+
67
+
68
+ def multi_gpu_test(model, data_loader, tmpdir=None, gpu_collect=False):
69
+ """Test model with multiple gpus.
70
+
71
+ This method tests model with multiple gpus and collects the results
72
+ under two different modes: gpu and cpu modes. By setting 'gpu_collect=True'
73
+ it encodes results to gpu tensors and use gpu communication for results
74
+ collection. On cpu mode it saves the results on different gpus to 'tmpdir'
75
+ and collects them by the rank 0 worker.
76
+
77
+ Args:
78
+ model (nn.Module): Model to be tested.
79
+ data_loader (nn.Dataloader): Pytorch data loader.
80
+ tmpdir (str): Path of directory to save the temporary results from
81
+ different gpus under cpu mode.
82
+ gpu_collect (bool): Option to use either gpu or cpu to collect results.
83
+
84
+ Returns:
85
+ list: The prediction results.
86
+ """
87
+ model.eval()
88
+ results = []
89
+ dataset = data_loader.dataset
90
+ rank, world_size = get_dist_info()
91
+ if rank == 0:
92
+ prog_bar = mmcv.ProgressBar(len(dataset))
93
+ time.sleep(2) # This line can prevent deadlock problem in some cases.
94
+ for i, data in enumerate(data_loader):
95
+ with torch.no_grad():
96
+ result = model(return_loss=False, rescale=True, **data)
97
+ # encode mask results
98
+ if isinstance(result[0], tuple):
99
+ result = [(bbox_results, encode_mask_results(mask_results))
100
+ for bbox_results, mask_results in result]
101
+ results.extend(result)
102
+
103
+ if rank == 0:
104
+ batch_size = len(result)
105
+ for _ in range(batch_size * world_size):
106
+ prog_bar.update()
107
+
108
+ # collect results from all ranks
109
+ if gpu_collect:
110
+ results = collect_results_gpu(results, len(dataset))
111
+ else:
112
+ results = collect_results_cpu(results, len(dataset), tmpdir)
113
+ return results
114
+
115
+
116
+ def collect_results_cpu(result_part, size, tmpdir=None):
117
+ rank, world_size = get_dist_info()
118
+ # create a tmp dir if it is not specified
119
+ if tmpdir is None:
120
+ MAX_LEN = 512
121
+ # 32 is whitespace
122
+ dir_tensor = torch.full((MAX_LEN, ),
123
+ 32,
124
+ dtype=torch.uint8,
125
+ device='cuda')
126
+ if rank == 0:
127
+ mmcv.mkdir_or_exist('.dist_test')
128
+ tmpdir = tempfile.mkdtemp(dir='.dist_test')
129
+ tmpdir = torch.tensor(
130
+ bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
131
+ dir_tensor[:len(tmpdir)] = tmpdir
132
+ dist.broadcast(dir_tensor, 0)
133
+ tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
134
+ else:
135
+ mmcv.mkdir_or_exist(tmpdir)
136
+ # dump the part result to the dir
137
+ mmcv.dump(result_part, osp.join(tmpdir, f'part_{rank}.pkl'))
138
+ dist.barrier()
139
+ # collect all parts
140
+ if rank != 0:
141
+ return None
142
+ else:
143
+ # load results of all parts from tmp dir
144
+ part_list = []
145
+ for i in range(world_size):
146
+ part_file = osp.join(tmpdir, f'part_{i}.pkl')
147
+ part_list.append(mmcv.load(part_file))
148
+ # sort the results
149
+ ordered_results = []
150
+ for res in zip(*part_list):
151
+ ordered_results.extend(list(res))
152
+ # the dataloader may pad some samples
153
+ ordered_results = ordered_results[:size]
154
+ # remove tmp dir
155
+ shutil.rmtree(tmpdir)
156
+ return ordered_results
157
+
158
+
159
+ def collect_results_gpu(result_part, size):
160
+ rank, world_size = get_dist_info()
161
+ # dump result part to tensor with pickle
162
+ part_tensor = torch.tensor(
163
+ bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
164
+ # gather all result part tensor shape
165
+ shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
166
+ shape_list = [shape_tensor.clone() for _ in range(world_size)]
167
+ dist.all_gather(shape_list, shape_tensor)
168
+ # padding result part tensor to max length
169
+ shape_max = torch.tensor(shape_list).max()
170
+ part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
171
+ part_send[:shape_tensor[0]] = part_tensor
172
+ part_recv_list = [
173
+ part_tensor.new_zeros(shape_max) for _ in range(world_size)
174
+ ]
175
+ # gather all result part
176
+ dist.all_gather(part_recv_list, part_send)
177
+
178
+ if rank == 0:
179
+ part_list = []
180
+ for recv, shape in zip(part_recv_list, shape_list):
181
+ part_list.append(
182
+ pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
183
+ # sort the results
184
+ ordered_results = []
185
+ for res in zip(*part_list):
186
+ ordered_results.extend(list(res))
187
+ # the dataloader may pad some samples
188
+ ordered_results = ordered_results[:size]
189
+ return ordered_results
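
Both test loops expect the model to be wrapped in an mmcv parallel container (they call model.module.show_result). A hedged single-GPU sketch, assuming cfg is a loaded mmdet config and model is already built:

    from mmcv.parallel import MMDataParallel
    from mmdet.datasets import build_dataloader, build_dataset

    dataset = build_dataset(cfg.data.test)
    data_loader = build_dataloader(dataset, samples_per_gpu=1,
                                   workers_per_gpu=2, dist=False, shuffle=False)
    model = MMDataParallel(model.cuda(0), device_ids=[0])
    outputs = single_gpu_test(model, data_loader, show=False,
                              out_dir='work_dirs/vis')  # illustrative out_dir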
mmdet/apis/train.py ADDED
@@ -0,0 +1,185 @@
1
+ import random
2
+ import warnings
3
+
4
+ import numpy as np
5
+ import torch
6
+ from mmcv.parallel import MMDataParallel, MMDistributedDataParallel
7
+ from mmcv.runner import (HOOKS, DistSamplerSeedHook, EpochBasedRunner,
8
+ Fp16OptimizerHook, OptimizerHook, build_optimizer,
9
+ build_runner)
10
+ from mmcv.utils import build_from_cfg
11
+
12
+ from mmdet.core import DistEvalHook, EvalHook
13
+ from mmdet.datasets import (build_dataloader, build_dataset,
14
+ replace_ImageToTensor)
15
+ from mmdet.utils import get_root_logger
16
+ from mmcv_custom.runner import EpochBasedRunnerAmp
17
+ try:
18
+ import apex
19
+ except ImportError:
20
+ print('apex is not installed')
21
+
22
+
23
+ def set_random_seed(seed, deterministic=False):
24
+ """Set random seed.
25
+
26
+ Args:
27
+ seed (int): Seed to be used.
28
+ deterministic (bool): Whether to set the deterministic option for
29
+ CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
30
+ to True and `torch.backends.cudnn.benchmark` to False.
31
+ Default: False.
32
+ """
33
+ random.seed(seed)
34
+ np.random.seed(seed)
35
+ torch.manual_seed(seed)
36
+ torch.cuda.manual_seed_all(seed)
37
+ if deterministic:
38
+ torch.backends.cudnn.deterministic = True
39
+ torch.backends.cudnn.benchmark = False
40
+
41
+
42
+ def train_detector(model,
43
+ dataset,
44
+ cfg,
45
+ distributed=False,
46
+ validate=False,
47
+ timestamp=None,
48
+ meta=None):
49
+ logger = get_root_logger(cfg.log_level)
50
+
51
+ # prepare data loaders
52
+ dataset = dataset if isinstance(dataset, (list, tuple)) else [dataset]
53
+ if 'imgs_per_gpu' in cfg.data:
54
+ logger.warning('"imgs_per_gpu" is deprecated in MMDet V2.0. '
55
+ 'Please use "samples_per_gpu" instead')
56
+ if 'samples_per_gpu' in cfg.data:
57
+ logger.warning(
58
+ f'Got "imgs_per_gpu"={cfg.data.imgs_per_gpu} and '
59
+ f'"samples_per_gpu"={cfg.data.samples_per_gpu}, "imgs_per_gpu"'
60
+ f'={cfg.data.imgs_per_gpu} is used in this experiment')
61
+ else:
62
+ logger.warning(
63
+ 'Automatically set "samples_per_gpu"="imgs_per_gpu"='
64
+ f'{cfg.data.imgs_per_gpu} in this experiment')
65
+ cfg.data.samples_per_gpu = cfg.data.imgs_per_gpu
66
+
67
+ data_loaders = [
68
+ build_dataloader(
69
+ ds,
70
+ cfg.data.samples_per_gpu,
71
+ cfg.data.workers_per_gpu,
72
+ # cfg.gpus will be ignored if distributed
73
+ len(cfg.gpu_ids),
74
+ dist=distributed,
75
+ seed=cfg.seed) for ds in dataset
76
+ ]
77
+
78
+ # build optimizer
79
+ optimizer = build_optimizer(model, cfg.optimizer)
80
+
81
+ # use apex fp16 optimizer
82
+ if cfg.optimizer_config.get("type", None) and cfg.optimizer_config["type"] == "DistOptimizerHook":
83
+ if cfg.optimizer_config.get("use_fp16", False):
84
+ model, optimizer = apex.amp.initialize(
85
+ model.cuda(), optimizer, opt_level="O1")
86
+ for m in model.modules():
87
+ if hasattr(m, "fp16_enabled"):
88
+ m.fp16_enabled = True
89
+
90
+ # put model on gpus
91
+ if distributed:
92
+ find_unused_parameters = cfg.get('find_unused_parameters', False)
93
+ # Sets the `find_unused_parameters` parameter in
94
+ # torch.nn.parallel.DistributedDataParallel
95
+ model = MMDistributedDataParallel(
96
+ model.cuda(),
97
+ device_ids=[torch.cuda.current_device()],
98
+ broadcast_buffers=False,
99
+ find_unused_parameters=find_unused_parameters)
100
+ else:
101
+ model = MMDataParallel(
102
+ model.cuda(cfg.gpu_ids[0]), device_ids=cfg.gpu_ids)
103
+
104
+ if 'runner' not in cfg:
105
+ cfg.runner = {
106
+ 'type': 'EpochBasedRunner',
107
+ 'max_epochs': cfg.total_epochs
108
+ }
109
+ warnings.warn(
110
+ 'config is now expected to have a `runner` section, '
111
+ 'please set `runner` in your config.', UserWarning)
112
+ else:
113
+ if 'total_epochs' in cfg:
114
+ assert cfg.total_epochs == cfg.runner.max_epochs
115
+
116
+ # build runner
117
+ runner = build_runner(
118
+ cfg.runner,
119
+ default_args=dict(
120
+ model=model,
121
+ optimizer=optimizer,
122
+ work_dir=cfg.work_dir,
123
+ logger=logger,
124
+ meta=meta))
125
+
126
+ # an ugly workaround to make .log and .log.json filenames the same
127
+ runner.timestamp = timestamp
128
+
129
+ # fp16 setting
130
+ fp16_cfg = cfg.get('fp16', None)
131
+ if fp16_cfg is not None:
132
+ optimizer_config = Fp16OptimizerHook(
133
+ **cfg.optimizer_config, **fp16_cfg, distributed=distributed)
134
+ elif distributed and 'type' not in cfg.optimizer_config:
135
+ optimizer_config = OptimizerHook(**cfg.optimizer_config)
136
+ else:
137
+ optimizer_config = cfg.optimizer_config
138
+
139
+ # register hooks
140
+ runner.register_training_hooks(cfg.lr_config, optimizer_config,
141
+ cfg.checkpoint_config, cfg.log_config,
142
+ cfg.get('momentum_config', None))
143
+ if distributed:
144
+ if isinstance(runner, EpochBasedRunner):
145
+ runner.register_hook(DistSamplerSeedHook())
146
+
147
+ # register eval hooks
148
+ if validate:
149
+ # Support batch_size > 1 in validation
150
+ val_samples_per_gpu = cfg.data.val.pop('samples_per_gpu', 1)
151
+ if val_samples_per_gpu > 1:
152
+ # Replace 'ImageToTensor' to 'DefaultFormatBundle'
153
+ cfg.data.val.pipeline = replace_ImageToTensor(
154
+ cfg.data.val.pipeline)
155
+ val_dataset = build_dataset(cfg.data.val, dict(test_mode=True))
156
+ val_dataloader = build_dataloader(
157
+ val_dataset,
158
+ samples_per_gpu=val_samples_per_gpu,
159
+ workers_per_gpu=cfg.data.workers_per_gpu,
160
+ dist=distributed,
161
+ shuffle=False)
162
+ eval_cfg = cfg.get('evaluation', {})
163
+ eval_cfg['by_epoch'] = cfg.runner['type'] != 'IterBasedRunner'
164
+ eval_hook = DistEvalHook if distributed else EvalHook
165
+ runner.register_hook(eval_hook(val_dataloader, **eval_cfg))
166
+
167
+ # user-defined hooks
168
+ if cfg.get('custom_hooks', None):
169
+ custom_hooks = cfg.custom_hooks
170
+ assert isinstance(custom_hooks, list), \
171
+ f'custom_hooks expect list type, but got {type(custom_hooks)}'
172
+ for hook_cfg in cfg.custom_hooks:
173
+ assert isinstance(hook_cfg, dict), \
174
+ 'Each item in custom_hooks expects dict type, but got ' \
175
+ f'{type(hook_cfg)}'
176
+ hook_cfg = hook_cfg.copy()
177
+ priority = hook_cfg.pop('priority', 'NORMAL')
178
+ hook = build_from_cfg(hook_cfg, HOOKS)
179
+ runner.register_hook(hook, priority=priority)
180
+
181
+ if cfg.resume_from:
182
+ runner.resume(cfg.resume_from)
183
+ elif cfg.load_from:
184
+ runner.load_checkpoint(cfg.load_from)
185
+ runner.run(data_loaders, cfg.workflow)
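For orientation, a minimal single-GPU driver for the training entry point above might look like the following sketch. The config path, work_dir, and the exact `build_detector` keyword arguments are assumptions for illustration, mirroring how mmdet-style launch scripts typically wire things together.

from mmcv import Config

from mmdet.datasets import build_dataset
from mmdet.models import build_detector

cfg = Config.fromfile('configs/walt/walt_vehicle.py')  # any config in this repo
cfg.work_dir = './work_dirs/walt_vehicle'
cfg.gpu_ids = [0]
cfg.seed = 0
cfg.resume_from = None
cfg.load_from = None

set_random_seed(cfg.seed, deterministic=False)
# train_cfg/test_cfg live at the top level of configs of this vintage
model = build_detector(
    cfg.model, train_cfg=cfg.get('train_cfg'), test_cfg=cfg.get('test_cfg'))
datasets = [build_dataset(cfg.data.train)]
train_detector(model, datasets, cfg, distributed=False, validate=False)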
mmdet/core/__init__.py ADDED
@@ -0,0 +1,7 @@
+ from .anchor import *  # noqa: F401, F403
+ from .bbox import *  # noqa: F401, F403
+ from .evaluation import *  # noqa: F401, F403
+ from .export import *  # noqa: F401, F403
+ from .mask import *  # noqa: F401, F403
+ from .post_processing import *  # noqa: F401, F403
+ from .utils import *  # noqa: F401, F403
mmdet/core/anchor/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .anchor_generator import (AnchorGenerator, LegacyAnchorGenerator,
+                                YOLOAnchorGenerator)
+ from .builder import ANCHOR_GENERATORS, build_anchor_generator
+ from .point_generator import PointGenerator
+ from .utils import anchor_inside_flags, calc_region, images_to_levels
+
+ __all__ = [
+     'AnchorGenerator', 'LegacyAnchorGenerator', 'anchor_inside_flags',
+     'PointGenerator', 'images_to_levels', 'calc_region',
+     'build_anchor_generator', 'ANCHOR_GENERATORS', 'YOLOAnchorGenerator'
+ ]
mmdet/core/anchor/anchor_generator.py ADDED
@@ -0,0 +1,727 @@
+ import mmcv
+ import numpy as np
+ import torch
+ from torch.nn.modules.utils import _pair
+
+ from .builder import ANCHOR_GENERATORS
+
+
+ @ANCHOR_GENERATORS.register_module()
+ class AnchorGenerator(object):
+     """Standard anchor generator for 2D anchor-based detectors.
+
+     Args:
+         strides (list[int] | list[tuple[int, int]]): Strides of anchors
+             in multiple feature levels in order (w, h).
+         ratios (list[float]): The list of ratios between the height and width
+             of anchors in a single level.
+         scales (list[int] | None): Anchor scales for anchors in a single level.
+             It cannot be set at the same time as `octave_base_scale` and
+             `scales_per_octave`.
+         base_sizes (list[int] | None): The basic sizes
+             of anchors in multiple levels.
+             If None is given, strides will be used as base_sizes.
+             (If strides are non-square, the shortest stride is taken.)
+         scale_major (bool): Whether to multiply scales first when generating
+             base anchors. If true, the anchors in the same row will have the
+             same scales. By default it is True in V2.0.
+         octave_base_scale (int): The base scale of octave.
+         scales_per_octave (int): Number of scales for each octave.
+             `octave_base_scale` and `scales_per_octave` are usually used in
+             retinanet and the `scales` should be None when they are set.
+         centers (list[tuple[float, float]] | None): The centers of the anchor
+             relative to the feature grid center in multiple feature levels.
+             By default it is set to be None and not used. If a list of tuple of
+             float is given, they will be used to shift the centers of anchors.
+         center_offset (float): The offset of center in proportion to anchors'
+             width and height. By default it is 0 in V2.0.
+
+     Examples:
+         >>> from mmdet.core import AnchorGenerator
+         >>> self = AnchorGenerator([16], [1.], [1.], [9])
+         >>> all_anchors = self.grid_anchors([(2, 2)], device='cpu')
+         >>> print(all_anchors)
+         [tensor([[-4.5000, -4.5000,  4.5000,  4.5000],
+                 [11.5000, -4.5000, 20.5000,  4.5000],
+                 [-4.5000, 11.5000,  4.5000, 20.5000],
+                 [11.5000, 11.5000, 20.5000, 20.5000]])]
+         >>> self = AnchorGenerator([16, 32], [1.], [1.], [9, 18])
+         >>> all_anchors = self.grid_anchors([(2, 2), (1, 1)], device='cpu')
+         >>> print(all_anchors)
+         [tensor([[-4.5000, -4.5000,  4.5000,  4.5000],
+                 [11.5000, -4.5000, 20.5000,  4.5000],
+                 [-4.5000, 11.5000,  4.5000, 20.5000],
+                 [11.5000, 11.5000, 20.5000, 20.5000]]), \
+         tensor([[-9., -9., 9., 9.]])]
+     """
+
+     def __init__(self,
+                  strides,
+                  ratios,
+                  scales=None,
+                  base_sizes=None,
+                  scale_major=True,
+                  octave_base_scale=None,
+                  scales_per_octave=None,
+                  centers=None,
+                  center_offset=0.):
+         # check center and center_offset
+         if center_offset != 0:
+             assert centers is None, 'center cannot be set when center_offset' \
+                 f'!=0, {centers} is given.'
+         if not (0 <= center_offset <= 1):
+             raise ValueError('center_offset should be in range [0, 1], '
+                              f'{center_offset} is given.')
+         if centers is not None:
+             assert len(centers) == len(strides), \
+                 'The number of strides should be the same as centers, got ' \
+                 f'{strides} and {centers}'
+
+         # calculate base sizes of anchors
+         self.strides = [_pair(stride) for stride in strides]
+         self.base_sizes = [min(stride) for stride in self.strides
+                            ] if base_sizes is None else base_sizes
+         assert len(self.base_sizes) == len(self.strides), \
+             'The number of strides should be the same as base sizes, got ' \
+             f'{self.strides} and {self.base_sizes}'
+
+         # calculate scales of anchors
+         assert ((octave_base_scale is not None
+                  and scales_per_octave is not None) ^ (scales is not None)), \
+             'scales and octave_base_scale with scales_per_octave cannot' \
+             ' be set at the same time'
+         if scales is not None:
+             self.scales = torch.Tensor(scales)
+         elif octave_base_scale is not None and scales_per_octave is not None:
+             octave_scales = np.array(
+                 [2**(i / scales_per_octave) for i in range(scales_per_octave)])
+             scales = octave_scales * octave_base_scale
+             self.scales = torch.Tensor(scales)
+         else:
+             raise ValueError('Either scales or octave_base_scale with '
+                              'scales_per_octave should be set')
+
+         self.octave_base_scale = octave_base_scale
+         self.scales_per_octave = scales_per_octave
+         self.ratios = torch.Tensor(ratios)
+         self.scale_major = scale_major
+         self.centers = centers
+         self.center_offset = center_offset
+         self.base_anchors = self.gen_base_anchors()
+
+     @property
+     def num_base_anchors(self):
+         """list[int]: total number of base anchors in a feature grid"""
+         return [base_anchors.size(0) for base_anchors in self.base_anchors]
+
+     @property
+     def num_levels(self):
+         """int: number of feature levels that the generator will be applied"""
+         return len(self.strides)
+
+     def gen_base_anchors(self):
+         """Generate base anchors.
+
+         Returns:
+             list(torch.Tensor): Base anchors of a feature grid in multiple \
+                 feature levels.
+         """
+         multi_level_base_anchors = []
+         for i, base_size in enumerate(self.base_sizes):
+             center = None
+             if self.centers is not None:
+                 center = self.centers[i]
+             multi_level_base_anchors.append(
+                 self.gen_single_level_base_anchors(
+                     base_size,
+                     scales=self.scales,
+                     ratios=self.ratios,
+                     center=center))
+         return multi_level_base_anchors
+
+     def gen_single_level_base_anchors(self,
+                                       base_size,
+                                       scales,
+                                       ratios,
+                                       center=None):
+         """Generate base anchors of a single level.
+
+         Args:
+             base_size (int | float): Basic size of an anchor.
+             scales (torch.Tensor): Scales of the anchor.
+             ratios (torch.Tensor): The ratio between the height
+                 and width of anchors in a single level.
+             center (tuple[float], optional): The center of the base anchor
+                 related to a single feature grid. Defaults to None.
+
+         Returns:
+             torch.Tensor: Anchors in a single-level feature map.
+         """
+         w = base_size
+         h = base_size
+         if center is None:
+             x_center = self.center_offset * w
+             y_center = self.center_offset * h
+         else:
+             x_center, y_center = center
+
+         h_ratios = torch.sqrt(ratios)
+         w_ratios = 1 / h_ratios
+         if self.scale_major:
+             ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
+             hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
+         else:
+             ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
+             hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
+
+         # use float anchor and the anchor's center is aligned with the
+         # pixel center
+         base_anchors = [
+             x_center - 0.5 * ws, y_center - 0.5 * hs, x_center + 0.5 * ws,
+             y_center + 0.5 * hs
+         ]
+         base_anchors = torch.stack(base_anchors, dim=-1)
+
+         return base_anchors
+
+     def _meshgrid(self, x, y, row_major=True):
+         """Generate mesh grid of x and y.
+
+         Args:
+             x (torch.Tensor): Grids of x dimension.
+             y (torch.Tensor): Grids of y dimension.
+             row_major (bool, optional): Whether to return y grids first.
+                 Defaults to True.
+
+         Returns:
+             tuple[torch.Tensor]: The mesh grids of x and y.
+         """
+         # use shape instead of len to keep tracing while exporting to onnx
+         xx = x.repeat(y.shape[0])
+         yy = y.view(-1, 1).repeat(1, x.shape[0]).view(-1)
+         if row_major:
+             return xx, yy
+         else:
+             return yy, xx
+
+     def grid_anchors(self, featmap_sizes, device='cuda'):
+         """Generate grid anchors in multiple feature levels.
+
+         Args:
+             featmap_sizes (list[tuple]): List of feature map sizes in
+                 multiple feature levels.
+             device (str): Device where the anchors will be put on.
+
+         Return:
+             list[torch.Tensor]: Anchors in multiple feature levels. \
+                 The sizes of each tensor should be [N, 4], where \
+                 N = width * height * num_base_anchors, width and height \
+                 are the sizes of the corresponding feature level, \
+                 num_base_anchors is the number of anchors for that level.
+         """
+         assert self.num_levels == len(featmap_sizes)
+         multi_level_anchors = []
+         for i in range(self.num_levels):
+             anchors = self.single_level_grid_anchors(
+                 self.base_anchors[i].to(device),
+                 featmap_sizes[i],
+                 self.strides[i],
+                 device=device)
+             multi_level_anchors.append(anchors)
+         return multi_level_anchors
+
+     def single_level_grid_anchors(self,
+                                   base_anchors,
+                                   featmap_size,
+                                   stride=(16, 16),
+                                   device='cuda'):
+         """Generate grid anchors of a single level.
+
+         Note:
+             This function is usually called by method ``self.grid_anchors``.
+
+         Args:
+             base_anchors (torch.Tensor): The base anchors of a feature grid.
+             featmap_size (tuple[int]): Size of the feature maps.
+             stride (tuple[int], optional): Stride of the feature map in order
+                 (w, h). Defaults to (16, 16).
+             device (str, optional): Device the tensor will be put on.
+                 Defaults to 'cuda'.
+
+         Returns:
+             torch.Tensor: Anchors in the overall feature maps.
+         """
+         # keep as Tensor, so that we can convert to ONNX correctly
+         feat_h, feat_w = featmap_size
+         shift_x = torch.arange(0, feat_w, device=device) * stride[0]
+         shift_y = torch.arange(0, feat_h, device=device) * stride[1]
+
+         shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+         shifts = torch.stack([shift_xx, shift_yy, shift_xx, shift_yy], dim=-1)
+         shifts = shifts.type_as(base_anchors)
+         # first feat_w elements correspond to the first row of shifts
+         # add A anchors (1, A, 4) to K shifts (K, 1, 4) to get
+         # shifted anchors (K, A, 4), reshape to (K*A, 4)
+
+         all_anchors = base_anchors[None, :, :] + shifts[:, None, :]
+         all_anchors = all_anchors.view(-1, 4)
+         # first A rows correspond to A anchors of (0, 0) in feature map,
+         # then (0, 1), (0, 2), ...
+         return all_anchors
+
+     def valid_flags(self, featmap_sizes, pad_shape, device='cuda'):
+         """Generate valid flags of anchors in multiple feature levels.
+
+         Args:
+             featmap_sizes (list(tuple)): List of feature map sizes in
+                 multiple feature levels.
+             pad_shape (tuple): The padded shape of the image.
+             device (str): Device where the anchors will be put on.
+
+         Return:
+             list(torch.Tensor): Valid flags of anchors in multiple levels.
+         """
+         assert self.num_levels == len(featmap_sizes)
+         multi_level_flags = []
+         for i in range(self.num_levels):
+             anchor_stride = self.strides[i]
+             feat_h, feat_w = featmap_sizes[i]
+             h, w = pad_shape[:2]
+             valid_feat_h = min(int(np.ceil(h / anchor_stride[1])), feat_h)
+             valid_feat_w = min(int(np.ceil(w / anchor_stride[0])), feat_w)
+             flags = self.single_level_valid_flags((feat_h, feat_w),
+                                                   (valid_feat_h, valid_feat_w),
+                                                   self.num_base_anchors[i],
+                                                   device=device)
+             multi_level_flags.append(flags)
+         return multi_level_flags
+
+     def single_level_valid_flags(self,
+                                  featmap_size,
+                                  valid_size,
+                                  num_base_anchors,
+                                  device='cuda'):
+         """Generate the valid flags of anchor in a single feature map.
+
+         Args:
+             featmap_size (tuple[int]): The size of feature maps.
+             valid_size (tuple[int]): The valid size of the feature maps.
+             num_base_anchors (int): The number of base anchors.
+             device (str, optional): Device where the flags will be put on.
+                 Defaults to 'cuda'.
+
+         Returns:
+             torch.Tensor: The valid flags of each anchor in a single level \
+                 feature map.
+         """
+         feat_h, feat_w = featmap_size
+         valid_h, valid_w = valid_size
+         assert valid_h <= feat_h and valid_w <= feat_w
+         valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+         valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+         valid_x[:valid_w] = 1
+         valid_y[:valid_h] = 1
+         valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+         valid = valid_xx & valid_yy
+         valid = valid[:, None].expand(valid.size(0),
+                                       num_base_anchors).contiguous().view(-1)
+         return valid
+
+     def __repr__(self):
+         """str: a string that describes the module"""
+         indent_str = '    '
+         repr_str = self.__class__.__name__ + '(\n'
+         repr_str += f'{indent_str}strides={self.strides},\n'
+         repr_str += f'{indent_str}ratios={self.ratios},\n'
+         repr_str += f'{indent_str}scales={self.scales},\n'
+         repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
+         repr_str += f'{indent_str}scale_major={self.scale_major},\n'
+         repr_str += f'{indent_str}octave_base_scale='
+         repr_str += f'{self.octave_base_scale},\n'
+         repr_str += f'{indent_str}scales_per_octave='
+         repr_str += f'{self.scales_per_octave},\n'
+         repr_str += f'{indent_str}num_levels={self.num_levels},\n'
+         repr_str += f'{indent_str}centers={self.centers},\n'
+         repr_str += f'{indent_str}center_offset={self.center_offset})'
+         return repr_str
+
+
+ @ANCHOR_GENERATORS.register_module()
+ class SSDAnchorGenerator(AnchorGenerator):
+     """Anchor generator for SSD.
+
+     Args:
+         strides (list[int] | list[tuple[int, int]]): Strides of anchors
+             in multiple feature levels.
+         ratios (list[float]): The list of ratios between the height and width
+             of anchors in a single level.
+         basesize_ratio_range (tuple(float)): Ratio range of anchors.
+         input_size (int): Size of feature map, 300 for SSD300,
+             512 for SSD512.
+         scale_major (bool): Whether to multiply scales first when generating
+             base anchors. If true, the anchors in the same row will have the
+             same scales. It is always set to be False in SSD.
+     """
+
+     def __init__(self,
+                  strides,
+                  ratios,
+                  basesize_ratio_range,
+                  input_size=300,
+                  scale_major=True):
+         assert len(strides) == len(ratios)
+         assert mmcv.is_tuple_of(basesize_ratio_range, float)
+
+         self.strides = [_pair(stride) for stride in strides]
+         self.input_size = input_size
+         self.centers = [(stride[0] / 2., stride[1] / 2.)
+                         for stride in self.strides]
+         self.basesize_ratio_range = basesize_ratio_range
+
+         # calculate anchor ratios and sizes
+         min_ratio, max_ratio = basesize_ratio_range
+         min_ratio = int(min_ratio * 100)
+         max_ratio = int(max_ratio * 100)
+         step = int(np.floor(max_ratio - min_ratio) / (self.num_levels - 2))
+         min_sizes = []
+         max_sizes = []
+         for ratio in range(int(min_ratio), int(max_ratio) + 1, step):
+             min_sizes.append(int(self.input_size * ratio / 100))
+             max_sizes.append(int(self.input_size * (ratio + step) / 100))
+         if self.input_size == 300:
+             if basesize_ratio_range[0] == 0.15:  # SSD300 COCO
+                 min_sizes.insert(0, int(self.input_size * 7 / 100))
+                 max_sizes.insert(0, int(self.input_size * 15 / 100))
+             elif basesize_ratio_range[0] == 0.2:  # SSD300 VOC
+                 min_sizes.insert(0, int(self.input_size * 10 / 100))
+                 max_sizes.insert(0, int(self.input_size * 20 / 100))
+             else:
+                 raise ValueError(
+                     'basesize_ratio_range[0] should be either 0.15 '
+                     'or 0.2 when input_size is 300, got '
+                     f'{basesize_ratio_range[0]}.')
+         elif self.input_size == 512:
+             if basesize_ratio_range[0] == 0.1:  # SSD512 COCO
+                 min_sizes.insert(0, int(self.input_size * 4 / 100))
+                 max_sizes.insert(0, int(self.input_size * 10 / 100))
+             elif basesize_ratio_range[0] == 0.15:  # SSD512 VOC
+                 min_sizes.insert(0, int(self.input_size * 7 / 100))
+                 max_sizes.insert(0, int(self.input_size * 15 / 100))
+             else:
+                 raise ValueError('basesize_ratio_range[0] should be either 0.1 '
+                                  'or 0.15 when input_size is 512, got'
+                                  f' {basesize_ratio_range[0]}.')
+         else:
+             raise ValueError('Only support 300 or 512 in SSDAnchorGenerator'
+                              f', got {self.input_size}.')
+
+         anchor_ratios = []
+         anchor_scales = []
+         for k in range(len(self.strides)):
+             scales = [1., np.sqrt(max_sizes[k] / min_sizes[k])]
+             anchor_ratio = [1.]
+             for r in ratios[k]:
+                 anchor_ratio += [1 / r, r]  # 4 or 6 ratio
+             anchor_ratios.append(torch.Tensor(anchor_ratio))
+             anchor_scales.append(torch.Tensor(scales))
+
+         self.base_sizes = min_sizes
+         self.scales = anchor_scales
+         self.ratios = anchor_ratios
+         self.scale_major = scale_major
+         self.center_offset = 0
+         self.base_anchors = self.gen_base_anchors()
+
+     def gen_base_anchors(self):
+         """Generate base anchors.
+
+         Returns:
+             list(torch.Tensor): Base anchors of a feature grid in multiple \
+                 feature levels.
+         """
+         multi_level_base_anchors = []
+         for i, base_size in enumerate(self.base_sizes):
+             base_anchors = self.gen_single_level_base_anchors(
+                 base_size,
+                 scales=self.scales[i],
+                 ratios=self.ratios[i],
+                 center=self.centers[i])
+             indices = list(range(len(self.ratios[i])))
+             indices.insert(1, len(indices))
+             base_anchors = torch.index_select(base_anchors, 0,
+                                               torch.LongTensor(indices))
+             multi_level_base_anchors.append(base_anchors)
+         return multi_level_base_anchors
+
+     def __repr__(self):
+         """str: a string that describes the module"""
+         indent_str = '    '
+         repr_str = self.__class__.__name__ + '(\n'
+         repr_str += f'{indent_str}strides={self.strides},\n'
+         repr_str += f'{indent_str}scales={self.scales},\n'
+         repr_str += f'{indent_str}scale_major={self.scale_major},\n'
+         repr_str += f'{indent_str}input_size={self.input_size},\n'
+         repr_str += f'{indent_str}scales={self.scales},\n'
+         repr_str += f'{indent_str}ratios={self.ratios},\n'
+         repr_str += f'{indent_str}num_levels={self.num_levels},\n'
+         repr_str += f'{indent_str}base_sizes={self.base_sizes},\n'
+         repr_str += f'{indent_str}basesize_ratio_range='
+         repr_str += f'{self.basesize_ratio_range})'
+         return repr_str
+
+
+ @ANCHOR_GENERATORS.register_module()
+ class LegacyAnchorGenerator(AnchorGenerator):
+     """Legacy anchor generator used in MMDetection V1.x.
+
+     Note:
+         Difference to the V2.0 anchor generator:
+
+         1. The center offset of V1.x anchors are set to be 0.5 rather than 0.
+         2. The width/height are decreased by 1 when calculating the anchors' \
+            centers and corners to meet the V1.x coordinate system.
+         3. The anchors' corners are quantized.
+
+     Args:
+         strides (list[int] | list[tuple[int]]): Strides of anchors
+             in multiple feature levels.
+         ratios (list[float]): The list of ratios between the height and width
+             of anchors in a single level.
+         scales (list[int] | None): Anchor scales for anchors in a single level.
+             It cannot be set at the same time as `octave_base_scale` and
+             `scales_per_octave`.
+         base_sizes (list[int]): The basic sizes of anchors in multiple levels.
+             If None is given, strides will be used to generate base_sizes.
+         scale_major (bool): Whether to multiply scales first when generating
+             base anchors. If true, the anchors in the same row will have the
+             same scales. By default it is True in V2.0.
+         octave_base_scale (int): The base scale of octave.
+         scales_per_octave (int): Number of scales for each octave.
+             `octave_base_scale` and `scales_per_octave` are usually used in
+             retinanet and the `scales` should be None when they are set.
+         centers (list[tuple[float, float]] | None): The centers of the anchor
+             relative to the feature grid center in multiple feature levels.
+             By default it is set to be None and not used. If a list of float
+             is given, this list will be used to shift the centers of anchors.
+         center_offset (float): The offset of center in proportion to anchors'
+             width and height. By default it is 0 in V2.0, but it should be 0.5
+             in v1.x models.
+
+     Examples:
+         >>> from mmdet.core import LegacyAnchorGenerator
+         >>> self = LegacyAnchorGenerator(
+         >>>     [16], [1.], [1.], [9], center_offset=0.5)
+         >>> all_anchors = self.grid_anchors(((2, 2),), device='cpu')
+         >>> print(all_anchors)
+         [tensor([[ 0.,  0.,  8.,  8.],
+                 [16.,  0., 24.,  8.],
+                 [ 0., 16.,  8., 24.],
+                 [16., 16., 24., 24.]])]
+     """
+
+     def gen_single_level_base_anchors(self,
+                                       base_size,
+                                       scales,
+                                       ratios,
+                                       center=None):
+         """Generate base anchors of a single level.
+
+         Note:
+             The width/height of anchors are decreased by 1 when calculating \
+             the centers and corners to meet the V1.x coordinate system.
+
+         Args:
+             base_size (int | float): Basic size of an anchor.
+             scales (torch.Tensor): Scales of the anchor.
+             ratios (torch.Tensor): The ratio between the height
+                 and width of anchors in a single level.
+             center (tuple[float], optional): The center of the base anchor
+                 related to a single feature grid. Defaults to None.
+
+         Returns:
+             torch.Tensor: Anchors in a single-level feature map.
+         """
+         w = base_size
+         h = base_size
+         if center is None:
+             x_center = self.center_offset * (w - 1)
+             y_center = self.center_offset * (h - 1)
+         else:
+             x_center, y_center = center
+
+         h_ratios = torch.sqrt(ratios)
+         w_ratios = 1 / h_ratios
+         if self.scale_major:
+             ws = (w * w_ratios[:, None] * scales[None, :]).view(-1)
+             hs = (h * h_ratios[:, None] * scales[None, :]).view(-1)
+         else:
+             ws = (w * scales[:, None] * w_ratios[None, :]).view(-1)
+             hs = (h * scales[:, None] * h_ratios[None, :]).view(-1)
+
+         # use float anchor and the anchor's center is aligned with the
+         # pixel center
+         base_anchors = [
+             x_center - 0.5 * (ws - 1), y_center - 0.5 * (hs - 1),
+             x_center + 0.5 * (ws - 1), y_center + 0.5 * (hs - 1)
+         ]
+         base_anchors = torch.stack(base_anchors, dim=-1).round()
+
+         return base_anchors
+
+
+ @ANCHOR_GENERATORS.register_module()
+ class LegacySSDAnchorGenerator(SSDAnchorGenerator, LegacyAnchorGenerator):
+     """Legacy anchor generator used in MMDetection V1.x.
+
+     The difference between `LegacySSDAnchorGenerator` and `SSDAnchorGenerator`
+     can be found in `LegacyAnchorGenerator`.
+     """
+
+     def __init__(self,
+                  strides,
+                  ratios,
+                  basesize_ratio_range,
+                  input_size=300,
+                  scale_major=True):
+         super(LegacySSDAnchorGenerator,
+               self).__init__(strides, ratios, basesize_ratio_range, input_size,
+                              scale_major)
+         self.centers = [((stride - 1) / 2., (stride - 1) / 2.)
+                         for stride in strides]
+         self.base_anchors = self.gen_base_anchors()
+
+
+ @ANCHOR_GENERATORS.register_module()
+ class YOLOAnchorGenerator(AnchorGenerator):
+     """Anchor generator for YOLO.
+
+     Args:
+         strides (list[int] | list[tuple[int, int]]): Strides of anchors
+             in multiple feature levels.
+         base_sizes (list[list[tuple[int, int]]]): The basic sizes
+             of anchors in multiple levels.
+     """
+
+     def __init__(self, strides, base_sizes):
+         self.strides = [_pair(stride) for stride in strides]
+         self.centers = [(stride[0] / 2., stride[1] / 2.)
+                         for stride in self.strides]
+         self.base_sizes = []
+         num_anchor_per_level = len(base_sizes[0])
+         for base_sizes_per_level in base_sizes:
+             assert num_anchor_per_level == len(base_sizes_per_level)
+             self.base_sizes.append(
+                 [_pair(base_size) for base_size in base_sizes_per_level])
+         self.base_anchors = self.gen_base_anchors()
+
+     @property
+     def num_levels(self):
+         """int: number of feature levels that the generator will be applied"""
+         return len(self.base_sizes)
+
+     def gen_base_anchors(self):
+         """Generate base anchors.
+
+         Returns:
+             list(torch.Tensor): Base anchors of a feature grid in multiple \
+                 feature levels.
+         """
+         multi_level_base_anchors = []
+         for i, base_sizes_per_level in enumerate(self.base_sizes):
+             center = None
+             if self.centers is not None:
+                 center = self.centers[i]
+             multi_level_base_anchors.append(
+                 self.gen_single_level_base_anchors(base_sizes_per_level,
+                                                    center))
+         return multi_level_base_anchors
+
+     def gen_single_level_base_anchors(self, base_sizes_per_level, center=None):
+         """Generate base anchors of a single level.
+
+         Args:
+             base_sizes_per_level (list[tuple[int, int]]): Basic sizes of
+                 anchors.
+             center (tuple[float], optional): The center of the base anchor
+                 related to a single feature grid. Defaults to None.
+
+         Returns:
+             torch.Tensor: Anchors in a single-level feature map.
+         """
+         x_center, y_center = center
+         base_anchors = []
+         for base_size in base_sizes_per_level:
+             w, h = base_size
+
+             # use float anchor and the anchor's center is aligned with the
+             # pixel center
+             base_anchor = torch.Tensor([
+                 x_center - 0.5 * w, y_center - 0.5 * h, x_center + 0.5 * w,
+                 y_center + 0.5 * h
+             ])
+             base_anchors.append(base_anchor)
+         base_anchors = torch.stack(base_anchors, dim=0)
+
+         return base_anchors
+
+     def responsible_flags(self, featmap_sizes, gt_bboxes, device='cuda'):
+         """Generate responsible anchor flags of grid cells in multiple scales.
+
+         Args:
+             featmap_sizes (list(tuple)): List of feature map sizes in multiple
+                 feature levels.
+             gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
+             device (str): Device where the anchors will be put on.
+
+         Return:
+             list(torch.Tensor): responsible flags of anchors in multiple level
+         """
+         assert self.num_levels == len(featmap_sizes)
+         multi_level_responsible_flags = []
+         for i in range(self.num_levels):
+             anchor_stride = self.strides[i]
+             flags = self.single_level_responsible_flags(
+                 featmap_sizes[i],
+                 gt_bboxes,
+                 anchor_stride,
+                 self.num_base_anchors[i],
+                 device=device)
+             multi_level_responsible_flags.append(flags)
+         return multi_level_responsible_flags
+
+     def single_level_responsible_flags(self,
+                                        featmap_size,
+                                        gt_bboxes,
+                                        stride,
+                                        num_base_anchors,
+                                        device='cuda'):
+         """Generate the responsible flags of anchor in a single feature map.
+
+         Args:
+             featmap_size (tuple[int]): The size of feature maps.
+             gt_bboxes (Tensor): Ground truth boxes, shape (n, 4).
+             stride (tuple(int)): stride of current level
+             num_base_anchors (int): The number of base anchors.
+             device (str, optional): Device where the flags will be put on.
+                 Defaults to 'cuda'.
+
+         Returns:
+             torch.Tensor: The valid flags of each anchor in a single level \
+                 feature map.
+         """
+         feat_h, feat_w = featmap_size
+         gt_bboxes_cx = ((gt_bboxes[:, 0] + gt_bboxes[:, 2]) * 0.5).to(device)
+         gt_bboxes_cy = ((gt_bboxes[:, 1] + gt_bboxes[:, 3]) * 0.5).to(device)
+         gt_bboxes_grid_x = torch.floor(gt_bboxes_cx / stride[0]).long()
+         gt_bboxes_grid_y = torch.floor(gt_bboxes_cy / stride[1]).long()
+
+         # row major indexing
+         gt_bboxes_grid_idx = gt_bboxes_grid_y * feat_w + gt_bboxes_grid_x
+
+         responsible_grid = torch.zeros(
+             feat_h * feat_w, dtype=torch.uint8, device=device)
+         responsible_grid[gt_bboxes_grid_idx] = 1
+
+         responsible_grid = responsible_grid[:, None].expand(
+             responsible_grid.size(0), num_base_anchors).contiguous().view(-1)
+         return responsible_grid
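As a quick sanity check on the generator above, the snippet below builds a RetinaNet-style configuration (the values are illustrative, not taken from this repo's configs): 3 scales per octave x 3 ratios gives 9 base anchors per location, so a 4x4 level yields 4*4*9 = 144 anchors.

from mmdet.core import AnchorGenerator

gen = AnchorGenerator(
    strides=[8, 16, 32, 64, 128],
    ratios=[0.5, 1.0, 2.0],
    octave_base_scale=4,
    scales_per_octave=3)
print(gen.num_base_anchors)  # [9, 9, 9, 9, 9]
anchors = gen.grid_anchors(
    [(4, 4), (2, 2), (1, 1), (1, 1), (1, 1)], device='cpu')
print(anchors[0].shape)  # torch.Size([144, 4])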
mmdet/core/anchor/builder.py ADDED
@@ -0,0 +1,7 @@
+ from mmcv.utils import Registry, build_from_cfg
+
+ ANCHOR_GENERATORS = Registry('Anchor generator')
+
+
+ def build_anchor_generator(cfg, default_args=None):
+     return build_from_cfg(cfg, ANCHOR_GENERATORS, default_args)
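Detector configs construct generators through this registry rather than importing classes directly; a minimal sketch (config values illustrative):

anchor_generator = build_anchor_generator(
    dict(
        type='AnchorGenerator',
        strides=[4, 8, 16, 32, 64],
        ratios=[0.5, 1.0, 2.0],
        scales=[8]))
print(anchor_generator.num_levels)  # 5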
mmdet/core/anchor/point_generator.py ADDED
@@ -0,0 +1,37 @@
+ import torch
+
+ from .builder import ANCHOR_GENERATORS
+
+
+ @ANCHOR_GENERATORS.register_module()
+ class PointGenerator(object):
+
+     def _meshgrid(self, x, y, row_major=True):
+         xx = x.repeat(len(y))
+         yy = y.view(-1, 1).repeat(1, len(x)).view(-1)
+         if row_major:
+             return xx, yy
+         else:
+             return yy, xx
+
+     def grid_points(self, featmap_size, stride=16, device='cuda'):
+         feat_h, feat_w = featmap_size
+         shift_x = torch.arange(0., feat_w, device=device) * stride
+         shift_y = torch.arange(0., feat_h, device=device) * stride
+         shift_xx, shift_yy = self._meshgrid(shift_x, shift_y)
+         stride = shift_x.new_full((shift_xx.shape[0], ), stride)
+         shifts = torch.stack([shift_xx, shift_yy, stride], dim=-1)
+         all_points = shifts.to(device)
+         return all_points
+
+     def valid_flags(self, featmap_size, valid_size, device='cuda'):
+         feat_h, feat_w = featmap_size
+         valid_h, valid_w = valid_size
+         assert valid_h <= feat_h and valid_w <= feat_w
+         valid_x = torch.zeros(feat_w, dtype=torch.bool, device=device)
+         valid_y = torch.zeros(feat_h, dtype=torch.bool, device=device)
+         valid_x[:valid_w] = 1
+         valid_y[:valid_h] = 1
+         valid_xx, valid_yy = self._meshgrid(valid_x, valid_y)
+         valid = valid_xx & valid_yy
+         return valid
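PointGenerator is the anchor-free counterpart: grid_points returns one (x, y, stride) triplet per feature-map location. A small CPU check:

import torch

pg = PointGenerator()
points = pg.grid_points((2, 3), stride=16, device='cpu')
print(points.shape)  # torch.Size([6, 3]) -> 2*3 locations
print(points[4])     # tensor([16., 16., 16.]) -> x, y, stride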
mmdet/core/anchor/utils.py ADDED
@@ -0,0 +1,71 @@
+ import torch
+
+
+ def images_to_levels(target, num_levels):
+     """Convert targets by image to targets by feature level.
+
+     [target_img0, target_img1] -> [target_level0, target_level1, ...]
+     """
+     target = torch.stack(target, 0)
+     level_targets = []
+     start = 0
+     for n in num_levels:
+         end = start + n
+         # level_targets.append(target[:, start:end].squeeze(0))
+         level_targets.append(target[:, start:end])
+         start = end
+     return level_targets
+
+
+ def anchor_inside_flags(flat_anchors,
+                         valid_flags,
+                         img_shape,
+                         allowed_border=0):
+     """Check whether the anchors are inside the border.
+
+     Args:
+         flat_anchors (torch.Tensor): Flatten anchors, shape (n, 4).
+         valid_flags (torch.Tensor): An existing valid flags of anchors.
+         img_shape (tuple(int)): Shape of current image.
+         allowed_border (int, optional): The border to allow the valid anchor.
+             Defaults to 0.
+
+     Returns:
+         torch.Tensor: Flags indicating whether the anchors are inside a \
+             valid range.
+     """
+     img_h, img_w = img_shape[:2]
+     if allowed_border >= 0:
+         inside_flags = valid_flags & \
+             (flat_anchors[:, 0] >= -allowed_border) & \
+             (flat_anchors[:, 1] >= -allowed_border) & \
+             (flat_anchors[:, 2] < img_w + allowed_border) & \
+             (flat_anchors[:, 3] < img_h + allowed_border)
+     else:
+         inside_flags = valid_flags
+     return inside_flags
+
+
+ def calc_region(bbox, ratio, featmap_size=None):
+     """Calculate a proportional bbox region.
+
+     The bbox center is fixed; with ratio r the new w' and h' are (1 - 2r) * w and (1 - 2r) * h.
+
+     Args:
+         bbox (Tensor): Bboxes to calculate regions, shape (n, 4).
+         ratio (float): Ratio of the output region.
+         featmap_size (tuple): Feature map size used for clipping the boundary.
+
+     Returns:
+         tuple: x1, y1, x2, y2
+     """
+     x1 = torch.round((1 - ratio) * bbox[0] + ratio * bbox[2]).long()
+     y1 = torch.round((1 - ratio) * bbox[1] + ratio * bbox[3]).long()
+     x2 = torch.round(ratio * bbox[0] + (1 - ratio) * bbox[2]).long()
+     y2 = torch.round(ratio * bbox[1] + (1 - ratio) * bbox[3]).long()
+     if featmap_size is not None:
+         x1 = x1.clamp(min=0, max=featmap_size[1])
+         y1 = y1.clamp(min=0, max=featmap_size[0])
+         x2 = x2.clamp(min=0, max=featmap_size[1])
+         y2 = y2.clamp(min=0, max=featmap_size[0])
+     return (x1, y1, x2, y2)
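calc_region keeps the central part of a box, rounded to integer coordinates. For example:

import torch

bbox = torch.tensor([0., 0., 100., 60.])
print(calc_region(bbox, 0.25))
# (tensor(25), tensor(15), tensor(75), tensor(45)) -> the middle half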
mmdet/core/bbox/__init__.py ADDED
@@ -0,0 +1,27 @@
+ from .assigners import (AssignResult, BaseAssigner, CenterRegionAssigner,
+                         MaxIoUAssigner, RegionAssigner)
+ from .builder import build_assigner, build_bbox_coder, build_sampler
+ from .coder import (BaseBBoxCoder, DeltaXYWHBBoxCoder, PseudoBBoxCoder,
+                     TBLRBBoxCoder)
+ from .iou_calculators import BboxOverlaps2D, bbox_overlaps
+ from .samplers import (BaseSampler, CombinedSampler,
+                        InstanceBalancedPosSampler, IoUBalancedNegSampler,
+                        OHEMSampler, PseudoSampler, RandomSampler,
+                        SamplingResult, ScoreHLRSampler)
+ from .transforms import (bbox2distance, bbox2result, bbox2roi,
+                          bbox_cxcywh_to_xyxy, bbox_flip, bbox_mapping,
+                          bbox_mapping_back, bbox_rescale, bbox_xyxy_to_cxcywh,
+                          distance2bbox, roi2bbox)
+
+ __all__ = [
+     'bbox_overlaps', 'BboxOverlaps2D', 'BaseAssigner', 'MaxIoUAssigner',
+     'AssignResult', 'BaseSampler', 'PseudoSampler', 'RandomSampler',
+     'InstanceBalancedPosSampler', 'IoUBalancedNegSampler', 'CombinedSampler',
+     'OHEMSampler', 'SamplingResult', 'ScoreHLRSampler', 'build_assigner',
+     'build_sampler', 'bbox_flip', 'bbox_mapping', 'bbox_mapping_back',
+     'bbox2roi', 'roi2bbox', 'bbox2result', 'distance2bbox', 'bbox2distance',
+     'build_bbox_coder', 'BaseBBoxCoder', 'PseudoBBoxCoder',
+     'DeltaXYWHBBoxCoder', 'TBLRBBoxCoder', 'CenterRegionAssigner',
+     'bbox_rescale', 'bbox_cxcywh_to_xyxy', 'bbox_xyxy_to_cxcywh',
+     'RegionAssigner'
+ ]
mmdet/core/bbox/assigners/__init__.py ADDED
@@ -0,0 +1,16 @@
+ from .approx_max_iou_assigner import ApproxMaxIoUAssigner
+ from .assign_result import AssignResult
+ from .atss_assigner import ATSSAssigner
+ from .base_assigner import BaseAssigner
+ from .center_region_assigner import CenterRegionAssigner
+ from .grid_assigner import GridAssigner
+ from .hungarian_assigner import HungarianAssigner
+ from .max_iou_assigner import MaxIoUAssigner
+ from .point_assigner import PointAssigner
+ from .region_assigner import RegionAssigner
+
+ __all__ = [
+     'BaseAssigner', 'MaxIoUAssigner', 'ApproxMaxIoUAssigner', 'AssignResult',
+     'PointAssigner', 'ATSSAssigner', 'CenterRegionAssigner', 'GridAssigner',
+     'HungarianAssigner', 'RegionAssigner'
+ ]
mmdet/core/bbox/assigners/approx_max_iou_assigner.py ADDED
@@ -0,0 +1,145 @@
+ import torch
+
+ from ..builder import BBOX_ASSIGNERS
+ from ..iou_calculators import build_iou_calculator
+ from .max_iou_assigner import MaxIoUAssigner
+
+
+ @BBOX_ASSIGNERS.register_module()
+ class ApproxMaxIoUAssigner(MaxIoUAssigner):
+     """Assign a corresponding gt bbox or background to each bbox.
+
+     Each proposal will be assigned with an integer indicating the ground truth
+     index. (semi-positive index: gt label (0-based), -1: background)
+
+     - -1: negative sample, no assigned gt
+     - semi-positive integer: positive sample, index (0-based) of assigned gt
+
+     Args:
+         pos_iou_thr (float): IoU threshold for positive bboxes.
+         neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+         min_pos_iou (float): Minimum iou for a bbox to be considered as a
+             positive bbox. Positive samples can have smaller IoU than
+             pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+         gt_max_assign_all (bool): Whether to assign all bboxes with the same
+             highest overlap with some gt to that gt.
+         ignore_iof_thr (float): IoF threshold for ignoring bboxes (if
+             `gt_bboxes_ignore` is specified). Negative values mean not
+             ignoring any bboxes.
+         ignore_wrt_candidates (bool): Whether to compute the iof between
+             `bboxes` and `gt_bboxes_ignore`, or the contrary.
+         match_low_quality (bool): Whether to allow low quality matches. This is
+             usually allowed for RPN and single stage detectors, but not allowed
+             in the second stage.
+         gpu_assign_thr (int): The upper bound of the number of GT for GPU
+             assign. When the number of gt is above this threshold, will assign
+             on CPU device. Negative values mean not assign on CPU.
+     """
+
+     def __init__(self,
+                  pos_iou_thr,
+                  neg_iou_thr,
+                  min_pos_iou=.0,
+                  gt_max_assign_all=True,
+                  ignore_iof_thr=-1,
+                  ignore_wrt_candidates=True,
+                  match_low_quality=True,
+                  gpu_assign_thr=-1,
+                  iou_calculator=dict(type='BboxOverlaps2D')):
+         self.pos_iou_thr = pos_iou_thr
+         self.neg_iou_thr = neg_iou_thr
+         self.min_pos_iou = min_pos_iou
+         self.gt_max_assign_all = gt_max_assign_all
+         self.ignore_iof_thr = ignore_iof_thr
+         self.ignore_wrt_candidates = ignore_wrt_candidates
+         self.gpu_assign_thr = gpu_assign_thr
+         self.match_low_quality = match_low_quality
+         self.iou_calculator = build_iou_calculator(iou_calculator)
+
+     def assign(self,
+                approxs,
+                squares,
+                approxs_per_octave,
+                gt_bboxes,
+                gt_bboxes_ignore=None,
+                gt_labels=None):
+         """Assign gt to approxs.
+
+         This method assigns a gt bbox to each group of approxs (bboxes);
+         each group of approxs is represented by a base approx (bbox) and
+         will be assigned with -1, or a semi-positive number.
+         background_label (-1) means negative sample,
+         a semi-positive number is the index (0-based) of the assigned gt.
+         The assignment is done in following steps, the order matters.
+
+         1. assign every bbox to background_label (-1)
+         2. use the max IoU of each group of approxs to assign
+         3. assign proposals whose iou with all gts < neg_iou_thr to background
+         4. for each bbox, if the iou with its nearest gt >= pos_iou_thr,
+            assign it to that bbox
+         5. for each gt bbox, assign its nearest proposals (may be more than
+            one) to itself
+
+         Args:
+             approxs (Tensor): Bounding boxes to be assigned,
+                 shape(approxs_per_octave*n, 4).
+             squares (Tensor): Base bounding boxes to be assigned,
+                 shape(n, 4).
+             approxs_per_octave (int): number of approxs per octave
+             gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+             gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                 labelled as `ignored`, e.g., crowd boxes in COCO.
+             gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+         Returns:
+             :obj:`AssignResult`: The assign result.
+         """
+         num_squares = squares.size(0)
+         num_gts = gt_bboxes.size(0)
+
+         if num_squares == 0 or num_gts == 0:
+             # No predictions and/or truth, return empty assignment
+             overlaps = approxs.new(num_gts, num_squares)
+             assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+             return assign_result
+
+         # re-organize anchors by approxs_per_octave x num_squares
+         approxs = torch.transpose(
+             approxs.view(num_squares, approxs_per_octave, 4), 0,
+             1).contiguous().view(-1, 4)
+         assign_on_cpu = True if (self.gpu_assign_thr > 0) and (
+             num_gts > self.gpu_assign_thr) else False
+         # compute overlap and assign gt on CPU when number of GT is large
+         if assign_on_cpu:
+             device = approxs.device
+             approxs = approxs.cpu()
+             gt_bboxes = gt_bboxes.cpu()
+             if gt_bboxes_ignore is not None:
+                 gt_bboxes_ignore = gt_bboxes_ignore.cpu()
+             if gt_labels is not None:
+                 gt_labels = gt_labels.cpu()
+         all_overlaps = self.iou_calculator(approxs, gt_bboxes)
+
+         overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
+                                         num_gts).max(dim=0)
+         overlaps = torch.transpose(overlaps, 0, 1)
+
+         if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+                 and gt_bboxes_ignore.numel() > 0 and squares.numel() > 0):
+             if self.ignore_wrt_candidates:
+                 ignore_overlaps = self.iou_calculator(
+                     squares, gt_bboxes_ignore, mode='iof')
+                 ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+             else:
+                 ignore_overlaps = self.iou_calculator(
+                     gt_bboxes_ignore, squares, mode='iof')
+                 ignore_max_overlaps, _ = ignore_overlaps.max(dim=0)
+             overlaps[:, ignore_max_overlaps > self.ignore_iof_thr] = -1
+
+         assign_result = self.assign_wrt_overlaps(overlaps, gt_labels)
+         if assign_on_cpu:
+             assign_result.gt_inds = assign_result.gt_inds.to(device)
+             assign_result.max_overlaps = assign_result.max_overlaps.to(device)
+             if assign_result.labels is not None:
+                 assign_result.labels = assign_result.labels.to(device)
+         return assign_result
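The heart of the assigner above is the reshape that reduces the per-octave approx boxes to one best IoU per base square. Isolated with toy sizes (random data, so this only demonstrates the shape mechanics):

import torch

approxs_per_octave, num_squares, num_gts = 3, 4, 2
all_overlaps = torch.rand(approxs_per_octave * num_squares, num_gts)
overlaps, _ = all_overlaps.view(approxs_per_octave, num_squares,
                                num_gts).max(dim=0)  # best approx per square
overlaps = torch.transpose(overlaps, 0, 1)
print(overlaps.shape)  # torch.Size([2, 4]) -> (num_gts, num_squares)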
mmdet/core/bbox/assigners/assign_result.py ADDED
@@ -0,0 +1,204 @@
+ import torch
+
+ from mmdet.utils import util_mixins
+
+
+ class AssignResult(util_mixins.NiceRepr):
+     """Stores assignments between predicted and truth boxes.
+
+     Attributes:
+         num_gts (int): the number of truth boxes considered when computing this
+             assignment
+
+         gt_inds (LongTensor): for each predicted box indicates the 1-based
+             index of the assigned truth box. 0 means unassigned and -1 means
+             ignore.
+
+         max_overlaps (FloatTensor): the iou between the predicted box and its
+             assigned truth box.
+
+         labels (None | LongTensor): If specified, for each predicted box
+             indicates the category label of the assigned truth box.
+
+     Example:
+         >>> # An assign result between 4 predicted boxes and 9 true boxes
+         >>> # where only two boxes were assigned.
+         >>> num_gts = 9
+         >>> max_overlaps = torch.FloatTensor([0, .5, .9, 0])
+         >>> gt_inds = torch.LongTensor([-1, 1, 2, 0])
+         >>> labels = torch.LongTensor([0, 3, 4, 0])
+         >>> self = AssignResult(num_gts, gt_inds, max_overlaps, labels)
+         >>> print(str(self))  # xdoctest: +IGNORE_WANT
+         <AssignResult(num_gts=9, gt_inds.shape=(4,), max_overlaps.shape=(4,),
+                       labels.shape=(4,))>
+         >>> # Force addition of gt labels (when adding gt as proposals)
+         >>> new_labels = torch.LongTensor([3, 4, 5])
+         >>> self.add_gt_(new_labels)
+         >>> print(str(self))  # xdoctest: +IGNORE_WANT
+         <AssignResult(num_gts=9, gt_inds.shape=(7,), max_overlaps.shape=(7,),
+                       labels.shape=(7,))>
+     """
+
+     def __init__(self, num_gts, gt_inds, max_overlaps, labels=None):
+         self.num_gts = num_gts
+         self.gt_inds = gt_inds
+         self.max_overlaps = max_overlaps
+         self.labels = labels
+         # Interface for possible user-defined properties
+         self._extra_properties = {}
+
+     @property
+     def num_preds(self):
+         """int: the number of predictions in this assignment"""
+         return len(self.gt_inds)
+
+     def set_extra_property(self, key, value):
+         """Set user-defined new property."""
+         assert key not in self.info
+         self._extra_properties[key] = value
+
+     def get_extra_property(self, key):
+         """Get user-defined property."""
+         return self._extra_properties.get(key, None)
+
+     @property
+     def info(self):
+         """dict: a dictionary of info about the object"""
+         basic_info = {
+             'num_gts': self.num_gts,
+             'num_preds': self.num_preds,
+             'gt_inds': self.gt_inds,
+             'max_overlaps': self.max_overlaps,
+             'labels': self.labels,
+         }
+         basic_info.update(self._extra_properties)
+         return basic_info
+
+     def __nice__(self):
+         """str: a "nice" summary string describing this assign result"""
+         parts = []
+         parts.append(f'num_gts={self.num_gts!r}')
+         if self.gt_inds is None:
+             parts.append(f'gt_inds={self.gt_inds!r}')
+         else:
+             parts.append(f'gt_inds.shape={tuple(self.gt_inds.shape)!r}')
+         if self.max_overlaps is None:
+             parts.append(f'max_overlaps={self.max_overlaps!r}')
+         else:
+             parts.append('max_overlaps.shape='
+                          f'{tuple(self.max_overlaps.shape)!r}')
+         if self.labels is None:
+             parts.append(f'labels={self.labels!r}')
+         else:
+             parts.append(f'labels.shape={tuple(self.labels.shape)!r}')
+         return ', '.join(parts)
+
+     @classmethod
+     def random(cls, **kwargs):
+         """Create random AssignResult for tests or debugging.
+
+         Args:
+             num_preds: number of predicted boxes
+             num_gts: number of true boxes
+             p_ignore (float): probability of a predicted box assigned to an
+                 ignored truth
+             p_assigned (float): probability of a predicted box not being
+                 assigned
+             p_use_label (float | bool): with labels or not
+             rng (None | int | numpy.random.RandomState): seed or state
+
+         Returns:
+             :obj:`AssignResult`: Randomly generated assign results.
+
+         Example:
+             >>> from mmdet.core.bbox.assigners.assign_result import *  # NOQA
+             >>> self = AssignResult.random()
+             >>> print(self.info)
+         """
+         from mmdet.core.bbox import demodata
+         rng = demodata.ensure_rng(kwargs.get('rng', None))
+
+         num_gts = kwargs.get('num_gts', None)
+         num_preds = kwargs.get('num_preds', None)
+         p_ignore = kwargs.get('p_ignore', 0.3)
+         p_assigned = kwargs.get('p_assigned', 0.7)
+         p_use_label = kwargs.get('p_use_label', 0.5)
+         num_classes = kwargs.get('num_classes', 3)
+
+         if num_gts is None:
+             num_gts = rng.randint(0, 8)
+         if num_preds is None:
+             num_preds = rng.randint(0, 16)
+
+         if num_gts == 0:
+             max_overlaps = torch.zeros(num_preds, dtype=torch.float32)
+             gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+             if p_use_label is True or p_use_label < rng.rand():
+                 labels = torch.zeros(num_preds, dtype=torch.int64)
+             else:
+                 labels = None
+         else:
+             import numpy as np
+             # Create an overlap for each predicted box
+             max_overlaps = torch.from_numpy(rng.rand(num_preds))
+
+             # Construct gt_inds for each predicted box
+             is_assigned = torch.from_numpy(rng.rand(num_preds) < p_assigned)
+             # maximum number of assignments constraints
+             n_assigned = min(num_preds, min(num_gts, is_assigned.sum()))
+
+             assigned_idxs = np.where(is_assigned)[0]
+             rng.shuffle(assigned_idxs)
+             assigned_idxs = assigned_idxs[0:n_assigned]
+             assigned_idxs.sort()
+
+             is_assigned[:] = 0
+             is_assigned[assigned_idxs] = True
+
+             is_ignore = torch.from_numpy(
+                 rng.rand(num_preds) < p_ignore) & is_assigned
+
+             gt_inds = torch.zeros(num_preds, dtype=torch.int64)
+
+             true_idxs = np.arange(num_gts)
+             rng.shuffle(true_idxs)
+             true_idxs = torch.from_numpy(true_idxs)
+             gt_inds[is_assigned] = true_idxs[:n_assigned]
+
+             gt_inds = torch.from_numpy(
+                 rng.randint(1, num_gts + 1, size=num_preds))
+             gt_inds[is_ignore] = -1
+             gt_inds[~is_assigned] = 0
+             max_overlaps[~is_assigned] = 0
+
+             if p_use_label is True or p_use_label < rng.rand():
+                 if num_classes == 0:
+                     labels = torch.zeros(num_preds, dtype=torch.int64)
+                 else:
+                     labels = torch.from_numpy(
+                         # remind that we set FG labels to [0, num_class-1]
+                         # since mmdet v2.0
+                         # BG cat_id: num_class
+                         rng.randint(0, num_classes, size=num_preds))
+                     labels[~is_assigned] = 0
+             else:
+                 labels = None
+
+         self = cls(num_gts, gt_inds, max_overlaps, labels)
+         return self
+
+     def add_gt_(self, gt_labels):
+         """Add ground truth as assigned results.
+
+         Args:
+             gt_labels (torch.Tensor): Labels of gt boxes
+         """
+         self_inds = torch.arange(
+             1, len(gt_labels) + 1, dtype=torch.long, device=gt_labels.device)
+         self.gt_inds = torch.cat([self_inds, self.gt_inds])
+
+         self.max_overlaps = torch.cat(
+             [self.max_overlaps.new_ones(len(gt_labels)), self.max_overlaps])
+
+         if self.labels is not None:
+             self.labels = torch.cat([gt_labels, self.labels])
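add_gt_ prepends ground-truth boxes as their own perfect assignments (overlap 1.0, 1-based self indices), which samplers rely on when gt boxes are added to the proposal set:

import torch

result = AssignResult(
    num_gts=2,
    gt_inds=torch.LongTensor([0, 2]),
    max_overlaps=torch.FloatTensor([0.0, 0.9]),
    labels=torch.LongTensor([-1, 5]))
result.add_gt_(torch.LongTensor([3, 5]))
print(result.gt_inds)       # tensor([1, 2, 0, 2]) -> gts come first
print(result.max_overlaps)  # tensor([1.0000, 1.0000, 0.0000, 0.9000])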
mmdet/core/bbox/assigners/atss_assigner.py ADDED
@@ -0,0 +1,178 @@
+ import torch
+
+ from ..builder import BBOX_ASSIGNERS
+ from ..iou_calculators import build_iou_calculator
+ from .assign_result import AssignResult
+ from .base_assigner import BaseAssigner
+
+
+ @BBOX_ASSIGNERS.register_module()
+ class ATSSAssigner(BaseAssigner):
+     """Assign a corresponding gt bbox or background to each bbox.
+
+     Each proposal will be assigned with `0` or a positive integer
+     indicating the ground truth index.
+
+     - 0: negative sample, no assigned gt
+     - positive integer: positive sample, index (1-based) of assigned gt
+
+     Args:
+         topk (int): number of bboxes selected on each level
+     """
+
+     def __init__(self,
+                  topk,
+                  iou_calculator=dict(type='BboxOverlaps2D'),
+                  ignore_iof_thr=-1):
+         self.topk = topk
+         self.iou_calculator = build_iou_calculator(iou_calculator)
+         self.ignore_iof_thr = ignore_iof_thr
+
+     # https://github.com/sfzhang15/ATSS/blob/master/atss_core/modeling/rpn/atss/loss.py
+
+     def assign(self,
+                bboxes,
+                num_level_bboxes,
+                gt_bboxes,
+                gt_bboxes_ignore=None,
+                gt_labels=None):
+         """Assign gt to bboxes.
+
+         The assignment is done in the following steps:
+
+         1. compute iou between all bboxes (bboxes of all pyramid levels)
+            and gts
+         2. compute the center distance between all bboxes and gts
+         3. on each pyramid level, for each gt, select k bboxes whose
+            centers are closest to the gt center, so we select k*l bboxes
+            in total as candidates for each gt
+         4. get the corresponding iou for these candidates, compute the
+            mean and std, and set mean + std as the iou threshold
+         5. select candidates whose iou is greater than or equal to
+            the threshold as positive
+         6. limit the positive samples' centers to lie inside the gt
+
+         Args:
+             bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+             num_level_bboxes (List): num of bboxes in each level
+             gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+             gt_bboxes_ignore (Tensor, optional): Ground truth bboxes that are
+                 labelled as `ignored`, e.g., crowd boxes in COCO.
+             gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+         Returns:
+             :obj:`AssignResult`: The assign result.
+         """
+         INF = 100000000
+         bboxes = bboxes[:, :4]
+         num_gt, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+         # compute iou between all bbox and gt
+         overlaps = self.iou_calculator(bboxes, gt_bboxes)
+
+         # assign 0 by default
+         assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+                                              0,
+                                              dtype=torch.long)
+
+         if num_gt == 0 or num_bboxes == 0:
+             # No ground truth or boxes, return empty assignment
+             max_overlaps = overlaps.new_zeros((num_bboxes, ))
+             if num_gt == 0:
+                 # No truth, assign everything to background
+                 assigned_gt_inds[:] = 0
+             if gt_labels is None:
+                 assigned_labels = None
+             else:
+                 assigned_labels = overlaps.new_full((num_bboxes, ),
+                                                     -1,
+                                                     dtype=torch.long)
+             return AssignResult(
+                 num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
+
+         # compute center distance between all bbox and gt
+         gt_cx = (gt_bboxes[:, 0] + gt_bboxes[:, 2]) / 2.0
+         gt_cy = (gt_bboxes[:, 1] + gt_bboxes[:, 3]) / 2.0
+         gt_points = torch.stack((gt_cx, gt_cy), dim=1)
+
+         bboxes_cx = (bboxes[:, 0] + bboxes[:, 2]) / 2.0
+         bboxes_cy = (bboxes[:, 1] + bboxes[:, 3]) / 2.0
+         bboxes_points = torch.stack((bboxes_cx, bboxes_cy), dim=1)
+
+         distances = (bboxes_points[:, None, :] -
+                      gt_points[None, :, :]).pow(2).sum(-1).sqrt()
+
+         if (self.ignore_iof_thr > 0 and gt_bboxes_ignore is not None
+                 and gt_bboxes_ignore.numel() > 0 and bboxes.numel() > 0):
+             ignore_overlaps = self.iou_calculator(
+                 bboxes, gt_bboxes_ignore, mode='iof')
+             ignore_max_overlaps, _ = ignore_overlaps.max(dim=1)
+             ignore_idxs = ignore_max_overlaps > self.ignore_iof_thr
+             distances[ignore_idxs, :] = INF
+             assigned_gt_inds[ignore_idxs] = -1
+
+         # Selecting candidates based on the center distance
+         candidate_idxs = []
+         start_idx = 0
+         for level, bboxes_per_level in enumerate(num_level_bboxes):
+             # on each pyramid level, for each gt,
+             # select k bboxes whose centers are closest to the gt center
+             end_idx = start_idx + bboxes_per_level
+             distances_per_level = distances[start_idx:end_idx, :]
+             selectable_k = min(self.topk, bboxes_per_level)
+             _, topk_idxs_per_level = distances_per_level.topk(
+                 selectable_k, dim=0, largest=False)
+             candidate_idxs.append(topk_idxs_per_level + start_idx)
+             start_idx = end_idx
+         candidate_idxs = torch.cat(candidate_idxs, dim=0)
+
+         # get the corresponding iou for these candidates, compute the
+         # mean and std, and set mean + std as the iou threshold
+         candidate_overlaps = overlaps[candidate_idxs, torch.arange(num_gt)]
+         overlaps_mean_per_gt = candidate_overlaps.mean(0)
+         overlaps_std_per_gt = candidate_overlaps.std(0)
+         overlaps_thr_per_gt = overlaps_mean_per_gt + overlaps_std_per_gt
+
+         is_pos = candidate_overlaps >= overlaps_thr_per_gt[None, :]
+
+         # limit the positive sample's center in gt
+         for gt_idx in range(num_gt):
+             candidate_idxs[:, gt_idx] += gt_idx * num_bboxes
+         ep_bboxes_cx = bboxes_cx.view(1, -1).expand(
+             num_gt, num_bboxes).contiguous().view(-1)
+         ep_bboxes_cy = bboxes_cy.view(1, -1).expand(
+             num_gt, num_bboxes).contiguous().view(-1)
+         candidate_idxs = candidate_idxs.view(-1)
+
+         # calculate the left, top, right, bottom distance between positive
+         # bbox center and gt side
+         l_ = ep_bboxes_cx[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 0]
+         t_ = ep_bboxes_cy[candidate_idxs].view(-1, num_gt) - gt_bboxes[:, 1]
+         r_ = gt_bboxes[:, 2] - ep_bboxes_cx[candidate_idxs].view(-1, num_gt)
+         b_ = gt_bboxes[:, 3] - ep_bboxes_cy[candidate_idxs].view(-1, num_gt)
+         is_in_gts = torch.stack([l_, t_, r_, b_], dim=1).min(dim=1)[0] > 0.01
+         is_pos = is_pos & is_in_gts
+
+         # if an anchor box is assigned to multiple gts,
+         # the one with the highest IoU will be selected.
+         overlaps_inf = torch.full_like(overlaps,
+                                        -INF).t().contiguous().view(-1)
+         index = candidate_idxs.view(-1)[is_pos.view(-1)]
+         overlaps_inf[index] = overlaps.t().contiguous().view(-1)[index]
+         overlaps_inf = overlaps_inf.view(num_gt, -1).t()
+
+         max_overlaps, argmax_overlaps = overlaps_inf.max(dim=1)
+         assigned_gt_inds[
+             max_overlaps != -INF] = argmax_overlaps[max_overlaps != -INF] + 1
+
+         if gt_labels is not None:
+             assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+             pos_inds = torch.nonzero(
+                 assigned_gt_inds > 0, as_tuple=False).squeeze()
+             if pos_inds.numel() > 0:
+                 assigned_labels[pos_inds] = gt_labels[
+                     assigned_gt_inds[pos_inds] - 1]
+         else:
+             assigned_labels = None
+         return AssignResult(
+             num_gt, assigned_gt_inds, max_overlaps, labels=assigned_labels)
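As a usage sketch (hypothetical anchors and shapes; assumes this repo's bundled mmdet is importable), the assigner is constructed with `topk` and called with the per-level anchor counts:

```python
import torch
from mmdet.core.bbox.assigners import ATSSAssigner

# Hypothetical: 5 anchors spread over two pyramid levels, one gt box.
anchors = torch.tensor([[0., 0., 8., 8.], [8., 0., 16., 8.],
                        [0., 8., 8., 16.], [8., 8., 16., 16.],
                        [0., 0., 16., 16.]])
num_level_bboxes = [4, 1]          # anchors per level, sums to len(anchors)
gt_bboxes = torch.tensor([[2., 2., 14., 14.]])
gt_labels = torch.tensor([0])

assigner = ATSSAssigner(topk=2)    # k closest candidates per level
result = assigner.assign(anchors, num_level_bboxes, gt_bboxes,
                         gt_labels=gt_labels)
print(result.gt_inds)              # 0 = background, 1 = assigned to the gt
```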
mmdet/core/bbox/assigners/base_assigner.py ADDED
@@ -0,0 +1,9 @@
+ from abc import ABCMeta, abstractmethod
+
+
+ class BaseAssigner(metaclass=ABCMeta):
+     """Base assigner that assigns boxes to ground truth boxes."""
+
+     @abstractmethod
+     def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+         """Assign boxes to either a ground truth box or a negative sample."""
mmdet/core/bbox/assigners/center_region_assigner.py ADDED
@@ -0,0 +1,335 @@
+ import torch
+
+ from ..builder import BBOX_ASSIGNERS
+ from ..iou_calculators import build_iou_calculator
+ from .assign_result import AssignResult
+ from .base_assigner import BaseAssigner
+
+
+ def scale_boxes(bboxes, scale):
+     """Expand an array of boxes by a given scale.
+
+     Args:
+         bboxes (Tensor): Shape (m, 4)
+         scale (float): The scale factor of bboxes
+
+     Returns:
+         (Tensor): Shape (m, 4). Scaled bboxes
+     """
+     assert bboxes.size(1) == 4
+     w_half = (bboxes[:, 2] - bboxes[:, 0]) * .5
+     h_half = (bboxes[:, 3] - bboxes[:, 1]) * .5
+     x_c = (bboxes[:, 2] + bboxes[:, 0]) * .5
+     y_c = (bboxes[:, 3] + bboxes[:, 1]) * .5
+
+     w_half *= scale
+     h_half *= scale
+
+     boxes_scaled = torch.zeros_like(bboxes)
+     boxes_scaled[:, 0] = x_c - w_half
+     boxes_scaled[:, 2] = x_c + w_half
+     boxes_scaled[:, 1] = y_c - h_half
+     boxes_scaled[:, 3] = y_c + h_half
+     return boxes_scaled
+
+
+ def is_located_in(points, bboxes):
+     """Are points located in bboxes.
+
+     Args:
+         points (Tensor): Points, shape: (m, 2).
+         bboxes (Tensor): Bounding boxes, shape: (n, 4).
+
+     Return:
+         Tensor: Flags indicating if points are located in bboxes, shape: (m, n).
+     """
+     assert points.size(1) == 2
+     assert bboxes.size(1) == 4
+     return (points[:, 0].unsqueeze(1) > bboxes[:, 0].unsqueeze(0)) & \
+            (points[:, 0].unsqueeze(1) < bboxes[:, 2].unsqueeze(0)) & \
+            (points[:, 1].unsqueeze(1) > bboxes[:, 1].unsqueeze(0)) & \
+            (points[:, 1].unsqueeze(1) < bboxes[:, 3].unsqueeze(0))
+
+
+ def bboxes_area(bboxes):
+     """Compute the area of an array of bboxes.
+
+     Args:
+         bboxes (Tensor): The coordinates of bboxes. Shape: (m, 4)
+
+     Returns:
+         Tensor: Area of the bboxes. Shape: (m, )
+     """
+     assert bboxes.size(1) == 4
+     w = (bboxes[:, 2] - bboxes[:, 0])
+     h = (bboxes[:, 3] - bboxes[:, 1])
+     areas = w * h
+     return areas
+
+
+ @BBOX_ASSIGNERS.register_module()
+ class CenterRegionAssigner(BaseAssigner):
+     """Assign pixels at the center region of a bbox as positive.
+
+     Each proposal will be assigned with `-1`, `0`, or a positive integer
+     indicating the ground truth index.
+     - -1: negative samples
+     - semi-positive numbers: positive sample, index (0-based) of assigned gt
+
+     Args:
+         pos_scale (float): Threshold within which pixels are
+             labelled as positive.
+         neg_scale (float): Threshold above which pixels are
+             labelled as negative.
+         min_pos_iof (float): Minimum iof of a pixel with a gt to be
+             labelled as positive. Default: 1e-2
+         ignore_gt_scale (float): Threshold within which the pixels
+             are ignored when the gt is labelled as shadowed. Default: 0.5
+         foreground_dominate (bool): If True, the bbox will be assigned as
+             positive when a gt's kernel region overlaps with another's shadowed
+             (ignored) region, otherwise it is set as ignored. Default to False.
+     """
+
+     def __init__(self,
+                  pos_scale,
+                  neg_scale,
+                  min_pos_iof=1e-2,
+                  ignore_gt_scale=0.5,
+                  foreground_dominate=False,
+                  iou_calculator=dict(type='BboxOverlaps2D')):
+         self.pos_scale = pos_scale
+         self.neg_scale = neg_scale
+         self.min_pos_iof = min_pos_iof
+         self.ignore_gt_scale = ignore_gt_scale
+         self.foreground_dominate = foreground_dominate
+         self.iou_calculator = build_iou_calculator(iou_calculator)
+
+     def get_gt_priorities(self, gt_bboxes):
+         """Get gt priorities according to their areas.
+
+         Smaller gts have higher priority.
+
+         Args:
+             gt_bboxes (Tensor): Ground truth boxes, shape (k, 4).
+
+         Returns:
+             Tensor: The priority of gts so that gts with larger priority are \
+               more likely to be assigned. Shape (k, )
+         """
+         gt_areas = bboxes_area(gt_bboxes)
+         # Rank all gt bbox areas. Smaller objects have larger priority
+         _, sort_idx = gt_areas.sort(descending=True)
+         sort_idx = sort_idx.argsort()
+         return sort_idx
+
+     def assign(self, bboxes, gt_bboxes, gt_bboxes_ignore=None, gt_labels=None):
+         """Assign gt to bboxes.
+
+         This method assigns gts to every bbox (proposal/anchor); each bbox \
+         will be assigned with -1 or a semi-positive number. -1 means \
+         negative sample; a semi-positive number is the index (0-based) of \
+         the assigned gt.
+
+         Args:
+             bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+             gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+             gt_bboxes_ignore (tensor, optional): Ground truth bboxes that are
+                 labelled as `ignored`, e.g., crowd boxes in COCO.
+             gt_labels (tensor, optional): Label of gt_bboxes, shape (num_gts,).
+
+         Returns:
+             :obj:`AssignResult`: The assigned result. Note that \
+               shadowed_labels of shape (N, 2) is also added as an \
+               `assign_result` attribute. `shadowed_labels` is a tensor \
+               composed of N pairs of [anchor_ind, class_label], where N \
+               is the number of anchors that lie in the outer region of a \
+               gt, anchor_ind is the shadowed anchor index and class_label \
+               is the shadowed class label.
+
+         Example:
+             >>> self = CenterRegionAssigner(0.2, 0.2)
+             >>> bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
+             >>> gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
+             >>> assign_result = self.assign(bboxes, gt_bboxes)
+             >>> expected_gt_inds = torch.LongTensor([1, 0])
+             >>> assert torch.all(assign_result.gt_inds == expected_gt_inds)
+         """
+         # There are in total 5 steps in the pixel assignment
+         # 1. Find core (the center region, say inner 0.2)
+         #    and shadow (the relatively outer part, say inner 0.2-0.5)
+         #    regions of every gt.
+         # 2. Find all prior bboxes that lie in gt_core and gt_shadow regions
+         # 3. Assign prior bboxes in gt_core with a one-hot id of the gt in
+         #    the image.
+         #    3.1. For overlapping objects, the prior bboxes in gt_core are
+         #         assigned to the object with the smallest area
+         # 4. Assign prior bboxes with class label according to their gt id.
+         #    4.1. Assign -1 to prior bboxes lying in shadowed gts
+         #    4.2. Assign positive prior boxes with the corresponding label
+         # 5. Find pixels lying in the shadow of an object and assign them with
+         #    background label, but set the loss weight of the corresponding
+         #    gt to zero.
+         assert bboxes.size(1) == 4, 'bboxes must have size of 4'
+         # 1. Find core positive and shadow region of every gt
+         gt_core = scale_boxes(gt_bboxes, self.pos_scale)
+         gt_shadow = scale_boxes(gt_bboxes, self.neg_scale)
+
+         # 2. Find prior bboxes that lie in gt_core and gt_shadow regions
+         bbox_centers = (bboxes[:, 2:4] + bboxes[:, 0:2]) / 2
+         # The center points lie within the gt boxes
+         is_bbox_in_gt = is_located_in(bbox_centers, gt_bboxes)
+         # Only calculate bbox and gt_core IoF. This enables small prior bboxes
+         #   to match large gts
+         bbox_and_gt_core_overlaps = self.iou_calculator(
+             bboxes, gt_core, mode='iof')
+         # The center point of effective priors should be within the gt box
+         is_bbox_in_gt_core = is_bbox_in_gt & (
+             bbox_and_gt_core_overlaps > self.min_pos_iof)  # shape (n, k)
+
+         is_bbox_in_gt_shadow = (
+             self.iou_calculator(bboxes, gt_shadow, mode='iof') >
+             self.min_pos_iof)
+         # Rule out center effective positive pixels
+         is_bbox_in_gt_shadow &= (~is_bbox_in_gt_core)
+
+         num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+         if num_gts == 0 or num_bboxes == 0:
+             # If no gts exist, assign all pixels to negative
+             assigned_gt_ids = \
+                 is_bbox_in_gt_core.new_zeros((num_bboxes,),
+                                              dtype=torch.long)
+             pixels_in_gt_shadow = assigned_gt_ids.new_empty((0, 2))
+         else:
+             # Step 3: assign a one-hot gt id to each pixel; smaller objects
+             #   have higher priority to claim the pixel.
+             sort_idx = self.get_gt_priorities(gt_bboxes)
+             assigned_gt_ids, pixels_in_gt_shadow = \
+                 self.assign_one_hot_gt_indices(is_bbox_in_gt_core,
+                                                is_bbox_in_gt_shadow,
+                                                gt_priority=sort_idx)
+
+         if gt_bboxes_ignore is not None and gt_bboxes_ignore.numel() > 0:
+             # Mark boxes whose centers fall in (scaled) ignored gts as ignored
+             gt_bboxes_ignore = scale_boxes(
+                 gt_bboxes_ignore, scale=self.ignore_gt_scale)
+             is_bbox_in_ignored_gts = is_located_in(bbox_centers,
+                                                    gt_bboxes_ignore)
+             is_bbox_in_ignored_gts = is_bbox_in_ignored_gts.any(dim=1)
+             assigned_gt_ids[is_bbox_in_ignored_gts] = -1
+
+         # 4. Assign prior bboxes with class label according to their gt id.
+         assigned_labels = None
+         shadowed_pixel_labels = None
+         if gt_labels is not None:
+             # Default assigned label is the background (-1)
+             assigned_labels = assigned_gt_ids.new_full((num_bboxes, ), -1)
+             pos_inds = torch.nonzero(
+                 assigned_gt_ids > 0, as_tuple=False).squeeze()
+             if pos_inds.numel() > 0:
+                 assigned_labels[pos_inds] = gt_labels[assigned_gt_ids[pos_inds]
+                                                       - 1]
+             # 5. Find pixels lying in the shadow of an object
+             shadowed_pixel_labels = pixels_in_gt_shadow.clone()
+             if pixels_in_gt_shadow.numel() > 0:
+                 pixel_idx, gt_idx = \
+                     pixels_in_gt_shadow[:, 0], pixels_in_gt_shadow[:, 1]
+                 assert (assigned_gt_ids[pixel_idx] != gt_idx).all(), \
+                     'Some pixels are dually assigned to ignore and gt!'
+                 shadowed_pixel_labels[:, 1] = gt_labels[gt_idx - 1]
+                 override = (
+                     assigned_labels[pixel_idx] == shadowed_pixel_labels[:, 1])
+                 if self.foreground_dominate:
+                     # When a pixel is both positive and shadowed, set it as pos
+                     shadowed_pixel_labels = shadowed_pixel_labels[~override]
+                 else:
+                     # When a pixel is both pos and shadowed, set it as shadowed
+                     assigned_labels[pixel_idx[override]] = -1
+                     assigned_gt_ids[pixel_idx[override]] = 0
+
+         assign_result = AssignResult(
+             num_gts, assigned_gt_ids, None, labels=assigned_labels)
+         # Add shadowed_labels as assign_result property. Shape: (num_shadow, 2)
+         assign_result.set_extra_property('shadowed_labels',
+                                          shadowed_pixel_labels)
+         return assign_result
+
+     def assign_one_hot_gt_indices(self,
+                                   is_bbox_in_gt_core,
+                                   is_bbox_in_gt_shadow,
+                                   gt_priority=None):
+         """Assign only one gt index to each prior box.
+
+         Gts with large gt_priority are more likely to be assigned.
+
+         Args:
+             is_bbox_in_gt_core (Tensor): Bool tensor indicating the bbox center
+                 is in the core area of a gt (e.g. 0-0.2).
+                 Shape: (num_prior, num_gt).
+             is_bbox_in_gt_shadow (Tensor): Bool tensor indicating the bbox
+                 center is in the shadowed area of a gt (e.g. 0.2-0.5).
+                 Shape: (num_prior, num_gt).
+             gt_priority (Tensor): Priorities of gts. The gt with a higher
+                 priority is more likely to be assigned to the bbox when the
+                 bbox matches multiple gts. Shape: (num_gt, ).
+
+         Returns:
+             tuple: Returns (assigned_gt_inds, shadowed_gt_inds).
+
+                 - assigned_gt_inds: The assigned gt index of each prior bbox \
+                     (i.e. index from 1 to num_gts). Shape: (num_prior, ).
+                 - shadowed_gt_inds: shadowed gt indices. It is a tensor of \
+                     shape (num_ignore, 2) with first column being the \
+                     shadowed prior bbox indices and the second column the \
+                     shadowed gt indices (1-based).
+         """
+         num_bboxes, num_gts = is_bbox_in_gt_core.shape
+
+         if gt_priority is None:
+             gt_priority = torch.arange(
+                 num_gts, device=is_bbox_in_gt_core.device)
+         assert gt_priority.size(0) == num_gts
+         # The bigger gt_priority, the more preferable to be assigned
+         # The assigned inds are by default 0 (background)
+         assigned_gt_inds = is_bbox_in_gt_core.new_zeros((num_bboxes, ),
+                                                         dtype=torch.long)
+         # Shadowed bboxes are assigned to background, but the corresponding
+         #   label is ignored during loss calculation, which is done through
+         #   shadowed_gt_inds
+         shadowed_gt_inds = torch.nonzero(is_bbox_in_gt_shadow, as_tuple=False)
+         if is_bbox_in_gt_core.sum() == 0:  # No gt match
+             shadowed_gt_inds[:, 1] += 1  # 1-based, for consistency
+             return assigned_gt_inds, shadowed_gt_inds
+
+         # The priority of each prior box and gt pair. If one prior box is
+         #   matched to multiple gts, only the pair with the highest priority
+         #   is saved
+         pair_priority = is_bbox_in_gt_core.new_full((num_bboxes, num_gts),
+                                                     -1,
+                                                     dtype=torch.long)
+
+         # Each bbox could match multiple gts.
+         # The following code deals with this situation
+         # Matched bboxes (to any gt). Shape: (num_pos_anchor, )
+         inds_of_match = torch.any(is_bbox_in_gt_core, dim=1)
+         # The matched gt index of each positive bbox. Length >= num_pos_anchor
+         #   since one bbox could match multiple gts
+         matched_bbox_gt_inds = torch.nonzero(
+             is_bbox_in_gt_core, as_tuple=False)[:, 1]
+         # Assign priority to each bbox-gt pair.
+         pair_priority[is_bbox_in_gt_core] = gt_priority[matched_bbox_gt_inds]
+         _, argmax_priority = pair_priority[inds_of_match].max(dim=1)
+         assigned_gt_inds[inds_of_match] = argmax_priority + 1  # 1-based
+         # Zero-out the assigned anchor box to filter the shadowed gt indices
+         is_bbox_in_gt_core[inds_of_match, argmax_priority] = 0
+         # Concat the shadowed indices due to overlapping with those outside
+         #   the effective scale. shape: (total_num_ignore, 2)
+         shadowed_gt_inds = torch.cat(
+             (shadowed_gt_inds, torch.nonzero(
+                 is_bbox_in_gt_core, as_tuple=False)),
+             dim=0)
+         # `is_bbox_in_gt_core` should be changed back to keep arguments intact.
+         is_bbox_in_gt_core[inds_of_match, argmax_priority] = 1
+         # 1-based shadowed gt indices, to be consistent with `assigned_gt_inds`
+         if shadowed_gt_inds.numel() > 0:
+             shadowed_gt_inds[:, 1] += 1
+         return assigned_gt_inds, shadowed_gt_inds
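Building on the docstring example above, a quick sketch (hypothetical values; assumes the bundled mmdet is importable) that also reads back the `shadowed_labels` extra property:

```python
import torch
from mmdet.core.bbox.assigners import CenterRegionAssigner

assigner = CenterRegionAssigner(pos_scale=0.2, neg_scale=0.5)
bboxes = torch.Tensor([[0, 0, 10, 10], [10, 10, 20, 20]])
gt_bboxes = torch.Tensor([[0, 0, 10, 10]])
gt_labels = torch.LongTensor([2])

result = assigner.assign(bboxes, gt_bboxes, gt_labels=gt_labels)
print(result.gt_inds)   # tensor([1, 0]): only the centered box is positive
print(result.labels)    # tensor([ 2, -1])
print(result.get_extra_property('shadowed_labels'))  # empty here: no overlap
```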
mmdet/core/bbox/assigners/grid_assigner.py ADDED
@@ -0,0 +1,155 @@
+ import torch
+
+ from ..builder import BBOX_ASSIGNERS
+ from ..iou_calculators import build_iou_calculator
+ from .assign_result import AssignResult
+ from .base_assigner import BaseAssigner
+
+
+ @BBOX_ASSIGNERS.register_module()
+ class GridAssigner(BaseAssigner):
+     """Assign a corresponding gt bbox or background to each bbox.
+
+     Each proposal will be assigned with `-1`, `0`, or a positive integer
+     indicating the ground truth index.
+
+     - -1: don't care
+     - 0: negative sample, no assigned gt
+     - positive integer: positive sample, index (1-based) of assigned gt
+
+     Args:
+         pos_iou_thr (float): IoU threshold for positive bboxes.
+         neg_iou_thr (float or tuple): IoU threshold for negative bboxes.
+         min_pos_iou (float): Minimum iou for a bbox to be considered as a
+             positive bbox. Positive samples can have smaller IoU than
+             pos_iou_thr due to the 4th step (assign max IoU sample to each gt).
+         gt_max_assign_all (bool): Whether to assign all bboxes with the same
+             highest overlap with some gt to that gt.
+     """
+
+     def __init__(self,
+                  pos_iou_thr,
+                  neg_iou_thr,
+                  min_pos_iou=.0,
+                  gt_max_assign_all=True,
+                  iou_calculator=dict(type='BboxOverlaps2D')):
+         self.pos_iou_thr = pos_iou_thr
+         self.neg_iou_thr = neg_iou_thr
+         self.min_pos_iou = min_pos_iou
+         self.gt_max_assign_all = gt_max_assign_all
+         self.iou_calculator = build_iou_calculator(iou_calculator)
+
+     def assign(self, bboxes, box_responsible_flags, gt_bboxes, gt_labels=None):
+         """Assign gt to bboxes. The process is very much like the max iou
+         assigner, except that positive samples are constrained to the cell
+         that the gt boxes fall in.
+
+         This method assigns a gt bbox to every bbox (proposal/anchor); each
+         bbox will be assigned with -1, 0, or a positive number. -1 means
+         don't care, 0 means negative sample, and a positive number is the
+         index (1-based) of the assigned gt.
+         The assignment is done in the following steps; the order matters.
+
+         1. assign every bbox to -1
+         2. assign proposals whose iou with all gts <= neg_iou_thr to 0
+         3. for each bbox within a cell, if the iou with its nearest gt >
+            pos_iou_thr and the center of that gt falls inside the cell,
+            assign it to that bbox
+         4. for each gt bbox, assign its nearest proposals within the cell
+            the gt bbox falls in to itself.
+
+         Args:
+             bboxes (Tensor): Bounding boxes to be assigned, shape(n, 4).
+             box_responsible_flags (Tensor): flag to indicate whether box is
+                 responsible for prediction, shape(n, )
+             gt_bboxes (Tensor): Groundtruth boxes, shape (k, 4).
+             gt_labels (Tensor, optional): Label of gt_bboxes, shape (k, ).
+
+         Returns:
+             :obj:`AssignResult`: The assign result.
+         """
+         num_gts, num_bboxes = gt_bboxes.size(0), bboxes.size(0)
+
+         # compute iou between all gt and bboxes
+         overlaps = self.iou_calculator(gt_bboxes, bboxes)
+
+         # 1. assign -1 by default
+         assigned_gt_inds = overlaps.new_full((num_bboxes, ),
+                                              -1,
+                                              dtype=torch.long)
+
+         if num_gts == 0 or num_bboxes == 0:
+             # No ground truth or boxes, return empty assignment
+             max_overlaps = overlaps.new_zeros((num_bboxes, ))
+             if num_gts == 0:
+                 # No truth, assign everything to background
+                 assigned_gt_inds[:] = 0
+             if gt_labels is None:
+                 assigned_labels = None
+             else:
+                 assigned_labels = overlaps.new_full((num_bboxes, ),
+                                                     -1,
+                                                     dtype=torch.long)
+             return AssignResult(
+                 num_gts,
+                 assigned_gt_inds,
+                 max_overlaps,
+                 labels=assigned_labels)
+
+         # 2. assign negative: below
+         # for each anchor, which gt best overlaps with it
+         # for each anchor, the max iou of all gts
+         # shape of max_overlaps == argmax_overlaps == num_bboxes
+         max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+         if isinstance(self.neg_iou_thr, float):
+             assigned_gt_inds[(max_overlaps >= 0)
+                              & (max_overlaps <= self.neg_iou_thr)] = 0
+         elif isinstance(self.neg_iou_thr, (tuple, list)):
+             assert len(self.neg_iou_thr) == 2
+             assigned_gt_inds[(max_overlaps > self.neg_iou_thr[0])
+                              & (max_overlaps <= self.neg_iou_thr[1])] = 0
+
+         # 3. assign positive: falls into responsible cell and above
+         # positive IOU threshold, the order matters.
+         # the prior condition of comparison is to filter out all
+         # unrelated anchors, i.e. not box_responsible_flags
+         overlaps[:, ~box_responsible_flags.type(torch.bool)] = -1.
+
+         # calculate max_overlaps again, but this time we only consider IOUs
+         # for anchors responsible for prediction
+         max_overlaps, argmax_overlaps = overlaps.max(dim=0)
+
+         # for each gt, which anchor best overlaps with it
+         # for each gt, the max iou of all proposals
+         # shape of gt_max_overlaps == gt_argmax_overlaps == num_gts
+         gt_max_overlaps, gt_argmax_overlaps = overlaps.max(dim=1)
+
+         pos_inds = (max_overlaps >
+                     self.pos_iou_thr) & box_responsible_flags.type(torch.bool)
+         assigned_gt_inds[pos_inds] = argmax_overlaps[pos_inds] + 1
+
+         # 4. assign positive to max overlapped anchors within responsible cell
+         for i in range(num_gts):
+             if gt_max_overlaps[i] > self.min_pos_iou:
+                 if self.gt_max_assign_all:
+                     max_iou_inds = (overlaps[i, :] == gt_max_overlaps[i]) & \
+                         box_responsible_flags.type(torch.bool)
+                     assigned_gt_inds[max_iou_inds] = i + 1
+                 elif box_responsible_flags[gt_argmax_overlaps[i]]:
+                     assigned_gt_inds[gt_argmax_overlaps[i]] = i + 1
+
+         # assign labels of positive anchors
+         if gt_labels is not None:
+             assigned_labels = assigned_gt_inds.new_full((num_bboxes, ), -1)
+             pos_inds = torch.nonzero(
+                 assigned_gt_inds > 0, as_tuple=False).squeeze()
+             if pos_inds.numel() > 0:
+                 assigned_labels[pos_inds] = gt_labels[
+                     assigned_gt_inds[pos_inds] - 1]
+         else:
+             assigned_labels = None
+
+         return AssignResult(
+             num_gts, assigned_gt_inds, max_overlaps, labels=assigned_labels)
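Finally, a usage sketch for the grid assigner (hypothetical boxes; in practice the responsibility flags come from the YOLO grid cell that contains each gt center):

```python
import torch
from mmdet.core.bbox.assigners import GridAssigner

assigner = GridAssigner(pos_iou_thr=0.5, neg_iou_thr=0.3)
bboxes = torch.Tensor([[0, 0, 10, 10], [2, 2, 12, 12], [20, 20, 30, 30]])
responsible = torch.BoolTensor([True, True, False])  # first cell holds the gt
gt_bboxes = torch.Tensor([[1, 1, 11, 11]])
gt_labels = torch.LongTensor([0])

result = assigner.assign(bboxes, responsible, gt_bboxes, gt_labels)
print(result.gt_inds)   # tensor([1, 1, 0]): positives only among responsible
```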