File size: 3,463 Bytes
be2715b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import os

from mmdet.apis import set_random_seed
from mmcv import Config

def _train_pipeline():
    """Return a fresh list of mmdet transform configs for the training split.

    Steps: load image + boxes, Resize to img_scale (512, 512) with
    keep_ratio=True, pad to a multiple of 32, random horizontal flip
    (p=0.5), an albumentations RandomSizedBBoxSafeCrop to 512x512, then
    normalise, format and collect the image, boxes and labels.
    """
    img_norm_cfg = dict(
        mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

    albu_train_transforms = [
        dict(
            type='RandomSizedBBoxSafeCrop',
            height=512,
            width=512,
            erosion_rate=0.2),
    ]

    return [
        dict(type='LoadImageFromFile'),
        dict(type='LoadAnnotations', with_bbox=True),
        dict(type='Resize', img_scale=(512, 512), keep_ratio=True),
        dict(type='Pad', size_divisor=32),
        dict(type='RandomFlip', flip_ratio=0.5),
        dict(
            type='Albu',
            transforms=albu_train_transforms,
            bbox_params=dict(
                type='BboxParams',
                format='pascal_voc',
                label_fields=['gt_labels'],
                min_visibility=0.0,
                filter_lost_elements=True),
            # Rename mmdet result keys to the names albumentations expects.
            keymap={
                'img': 'image',
                'gt_bboxes': 'bboxes'
            },
            update_pad_shape=False,
            skip_img_without_anno=True),
        dict(type='Normalize', **img_norm_cfg),
        dict(type='DefaultFormatBundle'),
        dict(type='Collect', keys=['img', 'gt_bboxes', 'gt_labels']),
    ]


def _configure_split(split_cfg, base_directory, ann_file, classes, dataset_type):
    """Point one ``cfg.data.<split>`` entry at ``<base_directory>/data``.

    Args:
        split_cfg: Dataset config to mutate in place (e.g. ``cfg.data.val``).
        base_directory: Project root that contains the ``data`` folder.
        ann_file: Annotation file name inside ``<base_directory>/data``.
        classes: Class-name tuple shared by every split.
        dataset_type: Registered mmdet dataset type name for this split.
    """
    data_dir = os.path.join(base_directory, 'data')
    # Joining with '' keeps the trailing separator ('<base>/data/'), matching
    # the prefix string this config has always produced.
    split_cfg.img_prefix = os.path.join(data_dir, '')
    split_cfg.ann_file = os.path.join(data_dir, ann_file)
    split_cfg.classes = classes
    split_cfg.type = dataset_type


def get_config(base_directory='.'):
    """Build the Faster R-CNN training config for this 14-class dataset.

    Loads ``<base_directory>/configs/faster_rcnn/faster_rcnn_r50_fpn_1x_coco.py``
    and overrides the classes, dataset paths, training pipeline, class
    balancing, schedule, logging, workflow and evaluation settings.

    Side effects: prints a banner and calls
    ``set_random_seed(1, deterministic=False)``.

    Args:
        base_directory: Project root holding ``configs/`` and ``data/``.

    Returns:
        The modified mmcv ``Config``.
    """
    print("Using base_config_track")
    cfg = Config.fromfile(os.path.join(
        base_directory, 'configs', 'faster_rcnn', 'faster_rcnn_r50_fpn_1x_coco.py'))

    cfg.classes = ("Aortic_enlargement", "Atelectasis", "Calcification", "Cardiomegaly", "Consolidation", "ILD", "Infiltration", "Lung_Opacity", "Nodule/Mass", "Other_lesion", "Pleural_effusion", "Pleural_thickening", "Pneumothorax", "Pulmonary_fibrosis")

    # --- Datasets -----------------------------------------------------------
    _configure_split(cfg.data.train, base_directory, 'train_annotations.json',
                     cfg.classes, 'CocoDatasetSubset')
    cfg.data.train.pipeline = _train_pipeline()

    # Wrap the training set in mmdet's ClassBalancedDataset, which repeats
    # images containing categories rarer than oversample_thr.
    cfg.data.train = dict(
        type='ClassBalancedDataset',
        oversample_thr=0.4,
        dataset=cfg.data.train
    )

    _configure_split(cfg.data.val, base_directory, 'valid_annotations.json',
                     cfg.classes, 'CocoDataset')
    _configure_split(cfg.data.test, base_directory, 'test_annotations.json',
                     cfg.classes, 'CocoDataset')

    # Derived from the class tuple so the head can never disagree with it.
    cfg.model.roi_head.bbox_head.num_classes = len(cfg.classes)

    # --- Optimisation -------------------------------------------------------
    # NOTE(review): mmdet's default lr of 0.02 is documented for a total batch
    # of 16 (8 GPUs x 2 imgs). 0.02 / 8 matches 1 GPU x 2 imgs, but
    # samples_per_gpu is set to 6 below, so the linear-scaling rule would give
    # 0.02 * 6 / 16. Confirm which value is intended.
    cfg.optimizer.lr = 0.02 / 8
    cfg.lr_config.warmup = None

    # Save a checkpoint every epoch (lower this frequency to save storage).
    cfg.checkpoint_config.interval = 1

    # Fixed seed for more reproducible runs (non-deterministic cuDNN allowed).
    cfg.seed = 1
    set_random_seed(1, deterministic=False)
    cfg.gpu_ids = range(1)

    # Optionally initialise from COCO-pretrained Faster R-CNN weights:
    # cfg.load_from = 'https://download.openmmlab.com/mmdetection/v2.0/faster_rcnn/faster_rcnn_r50_fpn_1x_coco/faster_rcnn_r50_fpn_1x_coco_20200130-047c8118.pth'
    # Not joined with base_directory: relative to the process working directory.
    cfg.work_dir = "../trained_weights"

    # One epoch takes around 18 mins (per the original note; hardware-dependent).
    # Both keys are set, presumably for compatibility with mmdet versions that
    # read total_epochs and those that read runner.max_epochs.
    num_epochs = 30
    cfg.total_epochs = num_epochs
    cfg.runner.max_epochs = num_epochs

    cfg.data.samples_per_gpu = 6

    # --- Logging / evaluation ----------------------------------------------
    cfg.log_config = dict(
        interval=50,  # iterations between log lines
        hooks=[
            dict(type='TensorboardLoggerHook'),
            dict(type='TextLoggerHook')
        ])

    # NOTE(review): mmdet's stock tools/train.py builds the 'val' workflow
    # dataset with cfg.data.train.pipeline. After the ClassBalancedDataset wrap
    # above, that pipeline lives at cfg.data.train.dataset.pipeline instead.
    # Confirm the training script used with this config handles that.
    cfg.workflow = [('train', 1), ('val', 1)]
    cfg.evaluation = dict(classwise=True, metric='bbox')

    return cfg