Pranjal2041 committed
Commit
970a7a2
1 Parent(s): 031b2c5

Initial demo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. .gitattributes +3 -0
  2. .gitignore +2 -0
  3. DenseMammogram/.gitignore +2 -0
  4. DenseMammogram/README.md +54 -0
  5. DenseMammogram/advanced_config.py +36 -0
  6. DenseMammogram/advanced_logger.py +48 -0
  7. DenseMammogram/all_graphs.py +156 -0
  8. DenseMammogram/auc_by_pranjal.py +120 -0
  9. DenseMammogram/dataloaders.py +259 -0
  10. DenseMammogram/detection/README.md +81 -0
  11. DenseMammogram/detection/coco_eval.py +191 -0
  12. DenseMammogram/detection/coco_utils.py +249 -0
  13. DenseMammogram/detection/engine.py +276 -0
  14. DenseMammogram/detection/group_by_aspect_ratio.py +196 -0
  15. DenseMammogram/detection/presets.py +47 -0
  16. DenseMammogram/detection/train.py +269 -0
  17. DenseMammogram/detection/transforms.py +283 -0
  18. DenseMammogram/detection/utils.py +282 -0
  19. DenseMammogram/ensemble_boxes/__init__.py +9 -0
  20. DenseMammogram/ensemble_boxes/ensemble_boxes_nms.py +249 -0
  21. DenseMammogram/ensemble_boxes/ensemble_boxes_nmw.py +202 -0
  22. DenseMammogram/ensemble_boxes/ensemble_boxes_wbf.py +269 -0
  23. DenseMammogram/ensemble_boxes/ensemble_boxes_wbf_3d.py +222 -0
  24. DenseMammogram/experimenter.py +213 -0
  25. DenseMammogram/froc_by_pranjal.py +236 -0
  26. DenseMammogram/geenerate_aiims.py +61 -0
  27. DenseMammogram/geenerate_ddsm_preds.py +61 -0
  28. DenseMammogram/geenerate_inbreast_preds.py +57 -0
  29. DenseMammogram/geenerate_irch.py +62 -0
  30. DenseMammogram/merge_predictions.py +152 -0
  31. DenseMammogram/model_utils.py +83 -0
  32. DenseMammogram/models.py +201 -0
  33. DenseMammogram/plot_froc.py +43 -0
  34. DenseMammogram/requirements.txt +11 -0
  35. DenseMammogram/train_bilateral.py +47 -0
  36. DenseMammogram/train_frcnn.py +34 -0
  37. DenseMammogram/utils.py +41 -0
  38. app.py +117 -0
  39. img_out1.jpg +3 -0
  40. img_out2.jpg +3 -0
  41. model.py +57 -0
  42. pretrained_models/AIIMS_C1/frcnn_models/frcnn_model.pth +3 -0
  43. pretrained_models/AIIMS_C2/frcnn_models/frcnn_model.pth +3 -0
  44. pretrained_models/AIIMS_C3/frcnn_models/frcnn_model.pth +3 -0
  45. pretrained_models/AIIMS_C4/frcnn_models/frcnn_model.pth +3 -0
  46. pretrained_models/AIIMS_T1/frcnn_models/frcnn_model.pth +3 -0
  47. pretrained_models/AIIMS_T2/frcnn_models/frcnn_model.pth +3 -0
  48. pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth +3 -0
  49. pretrained_models/frcnn/frcnn_models/frcnn_model.pth +3 -0
  50. requirements.txt +12 -0
.gitattributes CHANGED
@@ -6,6 +6,7 @@
  *.ftz filter=lfs diff=lfs merge=lfs -text
  *.gz filter=lfs diff=lfs merge=lfs -text
  *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.jpg filter=lfs diff=lfs merge=lfs -text
  *.joblib filter=lfs diff=lfs merge=lfs -text
  *.lfs.* filter=lfs diff=lfs merge=lfs -text
  *.mlmodel filter=lfs diff=lfs merge=lfs -text
@@ -32,3 +33,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ pretrained_models filter=lfs diff=lfs merge=lfs -text
+ sample_images filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ Demo.ipynb
+ __pycache__
DenseMammogram/.gitignore ADDED
@@ -0,0 +1,2 @@
+ pretrained_models
+ __pycache__
DenseMammogram/README.md ADDED
@@ -0,0 +1,54 @@
+ # Deep Learning for Detection of Iso-Dense, Obscure Masses in Mammographically Dense Breasts
+ [![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/)
+
+ ## Introduction
+ Deep Learning for Detection of Iso-Dense, Obscure Masses in Mammographically Dense Breasts is a paper on an object detection method for finding malignant masses in breast mammograms. Our model is particularly useful for dense breasts and for iso-dense and obscure masses. This repository includes the code and pretrained weights for the paper, along with all the scripts needed to replicate the numbers reported in the paper (our private dataset is not included).
+
+ ## Getting Started
+
+ First clone the repo:
+ ```bash
+ git clone https://github.com/Pranjal2041/DenseMammograms.git
+ ```
+
+ Next set up the environment using `conda` or `virtualenv`:
+ ```bash
+ # Option 1: conda
+ conda create -n densebreast python=3.7
+ conda activate densebreast
+ pip install -r requirements.txt
+
+ # Option 2: virtualenv
+ python -m venv densebreast
+ source densebreast/bin/activate
+ pip install -r requirements.txt
+ ```
+
+ ## Pretrained Weights
+
+ You can download the pretrained models from this [url](https://csciitd-my.sharepoint.com/:f:/g/personal/cs5190443_iitd_ac_in/ElTbduIuI49EougSH05Tb4IBhbc5gXCrlok0X_xvAI196g?e=Ss2eS1) and place them in the current directory.
+ <br>
+
+ ## Running the Code
+
+ To generate predictions and FROC graphs using the pretrained models, run:
+ `python all_graphs.py`
+
+ For running individual models on other datasets, `geenerate_{dataset}_preds.py` scripts have been provided.
+ For example, to run predictions on INbreast, run:
+ `python geenerate_inbreast_preds.py`
+
+ ## Demo
+
+ You can use either the **Google Colab demo** or the **Hugging Face demo**
+
+ ## Citation
+
+ Details coming soon!
+
+ ## License
+
+ TODO: Add License
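A minimal inference sketch can help connect the pretrained weights above to actual predictions. This is a hedged illustration only: it assumes the released `frcnn_model.pth` is a plain state dict for torchvision's `fasterrcnn_resnet50_fpn` with two classes, and the sample image path is a placeholder; `model.py` in this commit is the authoritative loader and may differ.

```python
# Hedged sketch: assumes frcnn_model.pth is a torchvision Faster R-CNN state dict
# with 2 classes (background + mass); sample_images/example.png is a placeholder path.
import torch
import torchvision
from PIL import Image
from torchvision.transforms import functional as F

model = torchvision.models.detection.fasterrcnn_resnet50_fpn(num_classes=2)
state = torch.load('pretrained_models/frcnn/frcnn_models/frcnn_model.pth', map_location='cpu')
model.load_state_dict(state)
model.eval()

img = F.to_tensor(Image.open('sample_images/example.png').convert('RGB'))
with torch.no_grad():
    out = model([img])[0]            # dict with 'boxes', 'labels', 'scores'
keep = out['scores'] > 0.5           # simple confidence cut-off for display
print(out['boxes'][keep], out['scores'][keep])
```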
DenseMammogram/advanced_config.py ADDED
@@ -0,0 +1,36 @@
+ import json
+ import os
+
+ class AdvancedConfig:
+
+     def save(self, file):
+         os.makedirs(os.path.split(file)[0], exist_ok=True)
+         json.dump(self.config, open(file, 'w'), indent=4)
+
+     def read_cfg(self, file):
+         # It's a JSON file that may contain '#' comments
+         new_lines = []
+         for line in open(file).readlines():
+             if line.find("#") != -1:
+                 new_lines.append(line[:line.find("#")])
+             else:
+                 new_lines.append(line)
+         return json.loads('\n'.join(new_lines))
+
+     def merge_config(self, cfg_dict, base_dict):
+         for key in cfg_dict:
+             if key not in base_dict:
+                 # Strange, raise an error
+                 raise Exception(f'Key {key} not found in base config')
+             if isinstance(cfg_dict[key], dict):
+                 base_dict[key] = self.merge_config(cfg_dict[key], base_dict[key])
+             else:
+                 base_dict[key] = cfg_dict[key]
+         return base_dict
+
+     def __init__(self, file, base_file = 'configs/default.cfg') -> None:
+         self.default_config = self.read_cfg(base_file)
+         self.new_config = self.read_cfg(file)
+         self.config = self.merge_config(self.new_config, self.default_config)
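A short usage sketch may clarify how `AdvancedConfig` is meant to be driven. The file names and keys below are hypothetical placeholders; the only real requirements are the comment-stripped-JSON format that `read_cfg` parses and the fact that every overriding key must already exist in the base config.

```python
# Hypothetical config files; any JSON whose lines may end in '#' comments works.
# configs/my_exp.cfg could contain:
#   {
#       "lr": 0.001,        # overrides the value in configs/default.cfg
#       "batch_size": 2
#   }
# Note: merge_config raises if a key is not present in the base config.
from advanced_config import AdvancedConfig

cfg = AdvancedConfig('configs/my_exp.cfg', base_file='configs/default.cfg')
print(cfg.config['lr'])                    # merged view: new values override defaults
cfg.save('experiments/exp1/config.json')   # writes the merged config back out
```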
DenseMammogram/advanced_logger.py ADDED
@@ -0,0 +1,48 @@
+ # An advanced logger class which writes data in a well-formatted manner
+ # to files based on different priorities.
+
+ from enum import Enum
+ import os
+ import datetime
+ import time
+
+ class LogPriority(Enum):
+     """
+     Enum class for the different log priorities.
+     """
+     LOW = 0
+     MEDIUM = 1
+     HIGH = 2
+     STATS = 3
+
+ class AdvancedLogger:
+
+     def __init__(self, base_dir):
+         self.base_dir = base_dir
+         self.files = []
+         self.file_names = []
+         for p in LogPriority:
+             self.file_names.append(os.path.join(self.base_dir, f'Log_{p.name}' + '.log'))
+             self.files.append(open(self.file_names[-1], 'w'))
+         self.last_log_time = -1
+
+     def flush(self):
+         for f in self.files:
+             f.close()
+         for i in range(len(self.files)):
+             self.files[i] = open(self.file_names[i], 'a')
+
+     def log(self, *args, priority = LogPriority.LOW):
+         to_log = ' '.join(map(str, args))
+         if priority.value <= LogPriority.MEDIUM.value:
+             # Prefix LOW/MEDIUM messages with the current time
+             now = datetime.datetime.now()
+             to_log = f'[{now.strftime("%H:%M:%S")}]: {to_log}'
+         print(to_log)
+         for p in range(priority.value + 1):
+             self.files[p].write(to_log + '\n')
+
+         # If more than 10s have passed since the last flush, or the priority is HIGH
+         # or above, close the files and re-open them in append mode
+         if time.time() - self.last_log_time > 10 or priority.value >= LogPriority.HIGH.value:
+             self.flush()
+             self.last_log_time = time.time()
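A brief, hedged usage sketch of the logger (the run directory name is hypothetical): one `.log` file is created per priority level, a message is written to every file whose priority is at or below the message's priority, LOW/MEDIUM messages get a timestamp prefix, and HIGH/STATS messages force a flush.

```python
import os
from advanced_logger import AdvancedLogger, LogPriority

os.makedirs('experiments/demo_run', exist_ok=True)            # hypothetical run directory
logger = AdvancedLogger('experiments/demo_run')                # opens Log_LOW/MEDIUM/HIGH/STATS.log

logger.log('starting epoch', 1)                                # LOW by default: timestamped, Log_LOW.log only
logger.log('saving checkpoint', priority=LogPriority.HIGH)     # written to LOW, MEDIUM and HIGH; flushes
logger.log('sens@0.1fpi', 0.81, priority=LogPriority.STATS)    # also lands in Log_STATS.log
```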
DenseMammogram/all_graphs.py ADDED
@@ -0,0 +1,156 @@
1
+ import os
2
+ from os.path import join
3
+ from merge_predictions import get_image_dict, apply_merge
4
+ from froc_by_pranjal import calc_froc_from_dict, pretty_print_fps
5
+ import numpy as np
6
+ import matplotlib.pyplot as plt
7
+
8
+
9
+ OUT_DIR = 'euro_results_auto'
10
+ numbers_dir = os.path.join(OUT_DIR, 'numbers')
11
+ graphs_dir = os.path.join(OUT_DIR, 'graphs')
12
+
13
+ BASE_FOLDER = '../bilateral_new/MammoDatasets'
14
+
15
+ MIN_CLIP_FPI = 0.02
16
+ def plot_froc(input_files, save_file, TITLE = 'FRCNN vs BILATERAL FROC', SHOW = False, CLIP_FPI = 1.2):
17
+ for file in input_files:
18
+ lines = open(file).readlines()
19
+ x = np.array([float(line.split()[0]) for line in lines])
20
+ y = np.array([float(line.split()[1]) for line in lines])
21
+ y = y[x<CLIP_FPI]
22
+ x = x[x<CLIP_FPI]
23
+ y = y[MIN_CLIP_FPI<x]
24
+ x = x[MIN_CLIP_FPI<x]
25
+ plt.plot(x, y, label = input_files[file])
26
+ plt.legend()
27
+
28
+ plt.title(TITLE)
29
+ plt.xlabel('Average False Positive Per Image')
30
+ plt.ylabel('Sensitivity')
31
+
32
+ if SHOW:
33
+ plt.show()
34
+ plt.savefig(save_file)
35
+ plt.clf()
36
+
37
+
38
+ dsets = [('AIIMS_highres_reliable', 'AIIMS'), ('IRCHVal', 'IRCHVal')]
39
+ dsets = dsets[1:]
40
+ for dset in dsets:
41
+ test_splits = ['test_2', 'test_dense', 'test_iso'][::-1]
42
+ for test_split in test_splits:
43
+ main_dataset = join(BASE_FOLDER, dset[0], test_split)
44
+
45
+ contrast_datasets = [join(BASE_FOLDER,f'{dset[1]}_C{i+1}',test_split) for i in range(4)]
46
+ threshold_datasets = [join(BASE_FOLDER,f'{dset[1]}_T{i+1}',test_split) for i in range(2)]
47
+ frcnn_preds = 'preds_frcnn_frcnn'
48
+ contrast_preds = [
49
+ 'preds_frcnn_AIIMS_C1',
50
+ 'preds_frcnn_AIIMS_C2',
51
+ 'preds_frcnn_AIIMS_C3',
52
+ 'preds_frcnn_AIIMS_C4',
53
+ ]
54
+ bilateral_preds = 'preds_bilateral_BILATERAL'
55
+ threshold_preds = [
56
+ 'preds_frcnn_AIIMS_T1',
57
+ 'preds_frcnn_AIIMS_T2',
58
+ ]
59
+
60
+ input_files = []
61
+ dataset_paths = [join(main_dataset, '{0}', frcnn_preds)]
62
+ dataset_paths +=[join(dset, '{0}', preds) for (dset,preds) in zip(contrast_datasets, contrast_preds)]
63
+ dataset_paths +=[join(dset, '{0}', preds) for (dset,preds) in zip(threshold_datasets, threshold_preds)]
64
+ dataset_paths +=[join(main_dataset, '{0}', bilateral_preds)]
65
+
66
+
67
+ CONFIGS = {
68
+ 'Baseline' : ('Baseline Model', [0]),
69
+ 'Bilateral' : ('Bilateral Model', [7]),
70
+ 'Contrast' : ('CABD Model', [0,1,2,3,4]),
71
+ 'Threshold' : ('TI Model', [0,5,6]),
72
+ 'Proposed' : ('Proposed Model', [1,2,3,4,5,6,7])
73
+ }
74
+
75
+ # Now handle the directories
76
+ num_dir = os.path.join(numbers_dir, dset[1], test_split)
77
+ os.makedirs(num_dir, exist_ok=True)
78
+
79
+
80
+ for config in CONFIGS:
81
+ title = CONFIGS[config][0]
82
+ allowed = CONFIGS[config][1]
83
+
84
+ weight_map = {
85
+ 0 : 1.,
86
+ 1 : 1,
87
+ 2 : 1.,
88
+ 3 : 1.,
89
+ 4 : .5, # C4
90
+ 5 : 0.5,
91
+ 6 : 0.5,
92
+ 7 : 1
93
+ }
94
+
95
+ weights = [weight_map[x] for x in allowed]
96
+
97
+ # generate the required mp dicts
98
+ def c2_manp(preds):
99
+ preds = list(filter(lambda x: x[0]>0.85,preds)) # keep only preds with confidence above 0.85
100
+ return preds
101
+
102
+ def c3_manp(preds):
103
+ preds = list(filter(lambda x: x[0]>0.85,preds)) # keep only preds with confidence above 0.85
104
+ return preds
105
+
106
+ def t1_manp(preds):
107
+ preds = list(filter(lambda x: x[0]>0.6,preds)) # keep only preds with confidence above 0.6
108
+ return preds
109
+
110
+ t2_manp = t1_manp
111
+ mp_dict = {
112
+ f'{dset[1]}_C2' : c2_manp,
113
+ f'{dset[1]}_C3' : c3_manp,
114
+ f'{dset[1]}_T1' : t1_manp,
115
+ f'{dset[1]}_T2' : t2_manp,
116
+ f'{dset[1]}_C4' : c3_manp
117
+ }
118
+
119
+ image_dict = get_image_dict(dataset_paths, allowed = allowed, USE_ACR = False, acr_cat = None, mp_dict = mp_dict)
120
+ image_dict = apply_merge(image_dict, METHOD = 'nms', weights= weights, conf_type='absent_model_aware_avg')
121
+
122
+
123
+ senses, fps = calc_froc_from_dict(image_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.], save_to = os.path.join(num_dir, f'{title}.txt'))
124
+
125
+
126
+ # Lets plot now
127
+
128
+ GRAPHS = [
129
+ ('Bilateral','Baseline'),
130
+ ('Contrast','Baseline'),
131
+ ('Threshold','Baseline'),
132
+ ('Proposed','Baseline'),
133
+ ('Proposed', 'Bilateral'),
134
+ ('Proposed', 'Contrast'),
135
+ ('Proposed', 'Threshold'),
136
+ ]
137
+
138
+
139
+ # Now handle the directories
140
+ graph_dir = os.path.join(graphs_dir, dset[1], test_split)
141
+ os.makedirs(graph_dir, exist_ok=True)
142
+
143
+ for graph in GRAPHS:
144
+ if graph[0] not in CONFIGS or graph[1] not in CONFIGS: continue
145
+ file_name1 = f'{CONFIGS[graph[0]][0]}.txt'
146
+ file_name2 = f'{CONFIGS[graph[1]][0]}.txt'
147
+
148
+ title1 = CONFIGS[graph[0]][0]
149
+ title2 = CONFIGS[graph[1]][0]
150
+
151
+ plot_froc({
152
+ join(num_dir, file_name1): title1,
153
+ join(num_dir, file_name2) : title2,
154
+ }, join(graph_dir,f'{title1}_vs_{title2}.png'),f'{title1} vs {title2} FROC', CLIP_FPI = 0.3 if dset[0] == 'IRCHVal' else 0.8)
155
+
156
+
DenseMammogram/auc_by_pranjal.py ADDED
@@ -0,0 +1,120 @@
1
+ import os
2
+ from os.path import join
3
+ import glob
4
+ from sklearn.metrics import roc_auc_score, roc_curve
5
+ import sys
6
+
7
+ def file_to_score(file):
8
+ try:
9
+ content = open(file, 'r').readlines()
10
+ st = 0
11
+ if len(content) == 0:
12
+ # Empty file: no detections, so the image score is 0
13
+ return 0.
14
+ if content[0].split()[0].isalpha():
15
+ st = 1
16
+ return max([float(line.split()[st]) for line in content])
17
+ except FileNotFoundError:
18
+ print(f'No Corresponding Box Found for file {file}, using score 0.')
+ return 0.
+ except Exception as e:
+ print('Error reading prediction file:', e)
+ return 0.
23
+
24
+ # Create the image dict
25
+ def generate_image_dict(preds_folder_name='preds_42',
26
+ root_fol='/home/krithika_1/densebreeast_datasets/AIIMS_C1',
27
+ mal_path=None, ben_path=None, gt_path=None,
28
+ mal_img_path = None, ben_img_path = None
29
+ ):
30
+
31
+ mal_path = join(root_fol, mal_path) if mal_path else join(
32
+ root_fol, 'mal', preds_folder_name)
33
+ ben_path = join(root_fol, ben_path) if ben_path else join(
34
+ root_fol, 'ben', preds_folder_name)
35
+ mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join(
36
+ root_fol, 'mal', 'images')
37
+ ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join(
38
+ root_fol, 'ben', 'images')
39
+ gt_path = join(root_fol, gt_path) if gt_path else join(
40
+ root_fol, 'mal', 'gt')
41
+
42
+
43
+ '''
44
+ image_dict structure:
45
+ 'image_name(without txt/png)' : {'gt' : [[...]], 'preds' : score}
46
+ '''
47
+ image_dict = dict()
48
+
49
+ # GT might be slightly different from images, therefore we will index gts based on
50
+ # the images folder instead.
51
+ for file in os.listdir(mal_img_path):
52
+ # for file in glob.glob(join(gt_path, '*.txt')):
53
+ if not file.endswith('.png'):
54
+ continue
55
+ file = file[:-4] + '.txt'
56
+ file = join(gt_path, file)
57
+ key = os.path.split(file)[-1][:-4]
58
+ image_dict[key] = dict()
59
+ image_dict[key]['gt'] = 1.
60
+ image_dict[key]['preds'] = 0.
61
+
62
+ for file in glob.glob(join(mal_path, '*.txt')):
63
+ key = os.path.split(file)[-1][:-4]
64
+ assert key in image_dict
65
+ image_dict[key]['preds'] = file_to_score(file)
66
+
67
+ for file in os.listdir(ben_img_path):
68
+ # for file in glob.glob(join(ben_path, '*.txt')):
69
+ if not file.endswith('.png'):
70
+ continue
71
+
72
+ file = file[:-4] + '.txt'
73
+ file = join(ben_path, file)
74
+ key = os.path.split(file)[-1][:-4]
75
+ # if key == 'Calc-Test_P_00353_LEFT_CC' or key == 'Calc-Training_P_00600_LEFT_CC':
76
+ # continue
77
+ if key in image_dict:
78
+ print(key)
79
+ print('SHIT')
80
+ continue
81
+ # assert key not in image_dict
82
+ image_dict[key] = dict()
83
+ image_dict[key]['preds'] = file_to_score(file)
84
+ image_dict[key]['gt'] = 0.
85
+ return image_dict
86
+
87
+ def get_auc_score_from_imdict(image_dict):
88
+ keys = list(image_dict.keys())
89
+ y = [image_dict[k]['gt']for k in keys]
90
+ preds = [image_dict[k]['preds']for k in keys]
91
+ return roc_auc_score(y, preds)
92
+
93
+ def get_accuracy_from_imdict(image_dict, thresh = 0.3):
94
+ keys = list(image_dict.keys())
95
+ ys = [image_dict[k]['gt']for k in keys]
96
+ preds = [image_dict[k]['preds']for k in keys]
97
+ acc = 0
98
+ for y,pred in zip(ys,preds):
99
+ if pred < thresh and y == 0.:
100
+ acc+=1
101
+ elif pred > thresh and y == 1.:
102
+ acc+=1
103
+ return acc/len(preds)
104
+
105
+
106
+ def get_auc_score(preds_image_folder, root_fol, retAcc = False, acc_thresh = 0.3):
107
+ im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol)
108
+ if retAcc:
109
+ return get_auc_score_from_imdict(im_dict), get_accuracy_from_imdict(im_dict, acc_thresh)
110
+ else:
111
+ return get_auc_score_from_imdict(im_dict)
112
+
113
+ if __name__ == '__main__':
114
+ seed = '42' if len(sys.argv)== 1 else sys.argv[1]
115
+
116
+ root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test'
117
+
118
+ auc_score = get_auc_score(f'preds_{seed}',root_fol)
119
+ print(f'ROC AUC Score: {auc_score}')
120
+
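The script scores each image with the maximum box confidence and feeds those scores to `roc_auc_score`. A self-contained sketch of the `image_dict` layout the metric helpers expect (the keys and values below are made up; run from the DenseMammogram directory):

```python
from auc_by_pranjal import get_auc_score_from_imdict, get_accuracy_from_imdict

image_dict = {
    'mal_case_001': {'gt': 1., 'preds': 0.92},   # malignant image, best box score 0.92
    'mal_case_002': {'gt': 1., 'preds': 0.40},
    'ben_case_001': {'gt': 0., 'preds': 0.10},   # benign image
    'ben_case_002': {'gt': 0., 'preds': 0.55},
}
print('AUC:', get_auc_score_from_imdict(image_dict))
print('Accuracy @0.3:', get_accuracy_from_imdict(image_dict, thresh=0.3))
```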
DenseMammogram/dataloaders.py ADDED
@@ -0,0 +1,259 @@
1
+ # Get the dataloaders
2
+ # There are only two types of dataloaders, viz. VanillaFRCNN and BilateralFRCNN
3
+
4
+ import torch
5
+ import cv2
6
+ import torchvision.transforms as T
7
+ import detection.transforms as transforms
8
+ from torch.utils.data import Dataset,DataLoader
9
+ import detection.utils as utils
10
+ import os
11
+ from tqdm import tqdm
12
+ import pandas as pd
13
+ from os.path import join
14
+ # VanillaFRCNN DataLoaders
15
+
16
+ class FRCNNDataset(Dataset):
17
+ def __init__(self,inputs,transform):
18
+ self.transform = transform
19
+ self.dataset_dicts = inputs
20
+
21
+ def __len__(self):
22
+ return len(self.dataset_dicts)
23
+
24
+
25
+ def __getitem__(self,index: int):
26
+ # Select the sample
27
+ record = self.dataset_dicts[index]
28
+ # Load input and target
29
+ img = cv2.imread(record['file_name'])
30
+
31
+ target = {k:torch.tensor(v) for k,v in record.items() if k != 'file_name'}
32
+ if self.transform is not None:
33
+ img = T.ToPILImage()(img)
34
+ img,target = self.transform(img,target)
35
+
36
+ return img,target
37
+
38
+ def xml_to_dicts(paths):
39
+ dataset_dicts = []
40
+ i=1
41
+ for path in paths:
42
+ for image in tqdm(os.listdir(os.path.join(path,'mal/images/'))):
43
+ xmlfile = os.path.join(path,'mal/gt/',image[:-4]+'.txt')
44
+ if(not os.path.exists(xmlfile)):
45
+ continue
46
+ img = cv2.imread(os.path.join(path,'mal/images/',image))
47
+ record = {}
48
+ record['file_name'] = os.path.join(path , 'mal/images/',image)
49
+ record['image_id'] = i
50
+ i+=1
51
+ record['width'] = img.shape[1]
52
+ record['height'] = img.shape[0]
53
+ objs = []
54
+ boxes = []
55
+ labels = []
56
+ area = []
57
+ iscrowd = []
58
+ f = open(xmlfile,'r')
59
+ for line in f.readlines():
60
+ box = list(map(int,map(float,line.split()[1:])))
61
+ boxes.append(box)
62
+ labels.append(1)
63
+ area.append((box[2]-box[0])*(box[3]-box[1]))
64
+ iscrowd.append(False)
65
+ f.close()
66
+ record["boxes"] = boxes
67
+ record["labels"] = labels
68
+ record["area"] = area
69
+ record["iscrowd"] = iscrowd
70
+ if(len(boxes)>0):
71
+ dataset_dicts.append(record)
72
+ for image in tqdm(os.listdir(os.path.join(path,'ben/images/'))):
73
+ img = cv2.imread(os.path.join(path,'ben/images/',image))
74
+ record = {}
75
+ record['file_name'] = os.path.join(path, 'ben/images/',image)
76
+ record['image_id'] = i
77
+ i+=1
78
+ record['width'] = img.shape[1]
79
+ record['height'] = img.shape[0]
80
+ record['boxes'] = torch.tensor([[0,0,img.shape[1],img.shape[0]]])
81
+ record['labels'] = torch.tensor([0])
82
+ record['area'] = [img.shape[1]*img.shape[0]]
83
+ record["iscrowd"] = [False]
84
+ dataset_dicts.append(record)
85
+ return dataset_dicts
86
+
87
+
88
+
89
+ def get_FRCNN_dataloaders(cfg, batch_size = 2, data_dir = '../bilateral_new',):
90
+ transform_test = transforms.Compose([transforms.ToTensor()])
91
+ transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
92
+ train_paths = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_TRAIN_SPLIT']),join(data_dir,cfg['DDSM_DATA'],cfg['DDSM_TRAIN_SPLIT']),]
93
+ val_aiims_path = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_VAL_SPLIT'])]
94
+ train_data = FRCNNDataset(xml_to_dicts(train_paths),transform_train)
95
+ test_aiims = FRCNNDataset(xml_to_dicts(val_aiims_path),transform_test)
96
+
97
+ train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
98
+ test_aiims_loader = DataLoader(test_aiims,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
99
+ #test_ddsm_loader = DataLoader(test_ddsm,batch_size=2,shuffle=True,drop_last=True,num_workers=5,collate_fn = utils.collate_fn)
100
+
101
+ return train_loader, test_aiims_loader
102
+
103
+ # BilaterialFRCNN DataLoaders
104
+
105
+ def get_direction(dset,file_name):
106
+ # 1 if right else -1
107
+ if dset == 'aiims' or dset == 'ddsm':
108
+ file_name = file_name.lower()
109
+ r = file_name.find('right')
110
+ l = file_name.find('left')
111
+ if l == r and l == -1:
112
+ raise Exception(f'Unidentifiable Direction {file_name}')
113
+ if l!=-1 and r!=-1:
114
+ raise Exception(f'Unidentifiable Direction {file_name}')
115
+ return 1 if r!=-1 else -1
116
+ if dset == 'inbreast':
117
+ dir =file_name.split('_')[3]
118
+ if dir == 'R': return 1
119
+ if dir == 'L': return -1
120
+ raise Exception(f'Unidentifiable Direction {file_name}')
121
+ if dset == 'irch':
122
+ r = file_name.find('_R ')
123
+ l = file_name.find('_L ')
124
+ if l == r and l == -1:
125
+ raise Exception(f'Unidentifiable Direction {file_name}')
126
+ if l!=-1 and r!=-1:
127
+ raise Exception(f'Unidentifiable Direction {file_name}')
128
+ return 1 if r!=-1 else -1
129
+
130
+
131
+ class BilateralDataset(torch.utils.data.Dataset):
132
+
133
+ def __init__(self,inputs,transform,dset):
134
+ self.transform = transform
135
+ self.dataset_dicts = inputs
136
+ self.dset = dset
137
+
138
+ def __len__(self):
139
+ return len(self.dataset_dicts)
140
+
141
+
142
+ def __getitem__(self,index: int):
143
+ # Select the sample
144
+ record = self.dataset_dicts[index]
145
+ # Load input and target
146
+ img1 = cv2.imread(record['file_name'])
147
+ img2 = cv2.imread(record['file_2'])
148
+
149
+ target = {k:torch.tensor(v) for k,v in record.items() if k != 'file_name' and k!='file_2'}
150
+ if self.transform is not None:
151
+ img1 = T.ToPILImage()(img1)
152
+ img2 = T.ToPILImage()(img2)
153
+ if(get_direction(self.dset,record['file_name'].split('/')[-1])==1):
154
+ img1,target = transforms.RandomHorizontalFlip(1.0)(img1,target)
155
+ else:
156
+ img2,_ = transforms.RandomHorizontalFlip(1.0)(img2)
157
+ img1,target = self.transform(img1,target)
158
+ img2,target = self.transform(img2,target)
159
+
160
+ images = [img1,img2]
161
+ return images,target
162
+
163
+
164
+ def xml_to_dicts_bilateral(paths,cor_dicts):
165
+ dataset_dicts = []
166
+ i=1
167
+ for path,cor_dict in zip(paths,cor_dicts):
168
+ for image in tqdm(os.listdir(os.path.join(path,'mal/images/'))):
169
+ if(not os.path.join(path,'mal/images/',image) in cor_dict):
170
+ continue
171
+ if(not os.path.isfile(cor_dict[os.path.join(path,'mal/images/',image)])):
172
+ continue
173
+ xmlfile = os.path.join(path,'mal/gt/',image[:-4]+'.txt')
174
+ if(not os.path.exists(xmlfile)):
175
+ continue
176
+ img = cv2.imread(os.path.join(path,'mal/images/',image))
177
+
178
+ record = {}
179
+ record['file_name'] = os.path.join(path , 'mal/images/',image)
180
+ record['file_2'] = cor_dict[os.path.join(path,'mal/images/',image)]
181
+ record['image_id'] = i
182
+ i+=1
183
+ record['width'] = img.shape[1]
184
+ record['height'] = img.shape[0]
185
+ objs = []
186
+ boxes = []
187
+ labels = []
188
+ area = []
189
+ iscrowd = []
190
+ f = open(xmlfile,'r')
191
+ for line in f.readlines():
192
+ box = list(map(int,map(float,line.split()[1:])))
193
+ boxes.append(box)
194
+ labels.append(1)
195
+ area.append((box[2]-box[0])*(box[3]-box[1]))
196
+ iscrowd.append(False)
197
+
198
+ f.close()
199
+ record["boxes"] = boxes
200
+ record["labels"] = labels
201
+ record["area"] = area
202
+ record["iscrowd"] = iscrowd
203
+ if(len(boxes)>0):
204
+ dataset_dicts.append(record)
205
+
206
+ for image in tqdm(os.listdir(os.path.join(path,'ben/images/'))):
207
+ if(not os.path.join(path,'ben/images/',image) in cor_dict):
208
+ continue
209
+ if(not os.path.isfile(cor_dict[os.path.join(path,'ben/images/',image)])):
210
+ continue
211
+ img = cv2.imread(os.path.join(path,'ben/images/',image))
212
+
213
+ record = {}
214
+ record['file_name'] = os.path.join(path , 'ben/images/',image)
215
+ record['file_2'] = cor_dict[os.path.join(path,'ben/images/',image)]
216
+ img2 = cv2.imread(cor_dict[os.path.join(path,'ben/images/',image)])
217
+ record['image_id'] = i
218
+ i+=1
219
+ record['width'] = img.shape[1]
220
+ record['height'] = img.shape[0]
221
+
222
+ record["boxes"] = torch.tensor([[0,0,min(img.shape[1],img2.shape[1]),min(img.shape[0],img2.shape[0])]])
223
+ record['labels'] = torch.tensor([0])
224
+ record['area'] = [ min(img.shape[1],img2.shape[1]) *min(img.shape[0],img2.shape[0])]
225
+ record["iscrowd"] = [False]
226
+ # benign records always carry a single whole-image box, so append unconditionally
+ dataset_dicts.append(record)
228
+
229
+ return dataset_dicts
230
+
231
+
232
+
233
+ def get_dict(data_dir, filename):
234
+ df = pd.read_csv(filename, header=None, sep=r'\s+', quotechar='"').to_numpy()
235
+ cor_dict = dict()
236
+ for a in df:
237
+ if(a[0]==a[1]):
238
+ continue
239
+ cor_dict[a[0]] = a[1]
240
+ # print(cor_dict)
241
+ cor_dict = {join(data_dir,k):join(data_dir,v) for k,v in cor_dict.items()}
242
+ return cor_dict
243
+
244
+ def get_bilateral_dataloaders(cfg, batch_size = 1, data_dir = '../bilateral_new'):
245
+ transform_test = transforms.Compose([transforms.ToTensor()])
246
+ transform_train = transforms.Compose([transforms.RandomHorizontalFlip(),transforms.ToTensor()])
247
+ train_paths = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_TRAIN_SPLIT']),join(data_dir,cfg['DDSM_DATA'],cfg['DDSM_TRAIN_SPLIT']),]
248
+ val_aiims_path = [join(data_dir,cfg['AIIMS_DATA'],cfg['AIIMS_VAL_SPLIT'])]
249
+ cor_lists_train = [get_dict(data_dir,join(data_dir,cfg['AIIMS_CORRS_LIST'])),get_dict(data_dir,join(data_dir,cfg['DDSM_CORRS_LIST']))]
250
+ cor_lists_val = [get_dict(data_dir,join(data_dir,cfg['AIIMS_CORRS_LIST']))]
251
+ cor_lists_train = [get_dict(data_dir,join(data_dir,cfg['AIIMS_CORRS_LIST']))]
252
+ train_data = BilateralDataset(xml_to_dicts_bilateral(train_paths,cor_lists_train),transform_test,'aiims')
253
+ val_aiims = BilateralDataset(xml_to_dicts_bilateral(val_aiims_path,cor_lists_val),transform_test,'aiims')
254
+
255
+ train_loader = DataLoader(train_data,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
256
+ val_aiims_loader = DataLoader(val_aiims,batch_size=batch_size,shuffle=True,drop_last=True,num_workers=4,collate_fn = utils.collate_fn)
257
+ #test_ddsm_loader = DataLoader(test_ddsm,batch_size=2,shuffle=True,drop_last=True,num_workers=5,collate_fn = utils.collate_fn)
258
+
259
+ return train_loader, val_aiims_loader
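The config keys read by `get_FRCNN_dataloaders` come from the repo's `configs/*.cfg` files, which are not part of this diff; the dictionary below is a hedged placeholder showing only the keys the function accesses and the directory layout the dataset loaders expect.

```python
from dataloaders import get_FRCNN_dataloaders

# Placeholder values; the real ones live in configs/*.cfg.
cfg = {
    'AIIMS_DATA': 'MammoDatasets/AIIMS_highres_reliable',
    'AIIMS_TRAIN_SPLIT': 'train',
    'AIIMS_VAL_SPLIT': 'val',
    'DDSM_DATA': 'MammoDatasets/DDSM',
    'DDSM_TRAIN_SPLIT': 'train',
}
# Each split folder is expected to contain mal/images, mal/gt (one text file per
# image with "label x1 y1 x2 y2" boxes) and ben/images.
train_loader, val_loader = get_FRCNN_dataloaders(cfg, batch_size=2, data_dir='../bilateral_new')
images, targets = next(iter(train_loader))
```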
DenseMammogram/detection/README.md ADDED
@@ -0,0 +1,81 @@
1
+ # Object detection reference training scripts
2
+
3
+ This folder contains reference training scripts for object detection.
4
+ They serve as a log of how to train specific models, to provide baseline
5
+ training and evaluation scripts to quickly bootstrap research.
6
+
7
+ To execute the example commands below you must install the following:
8
+
9
+ ```
10
+ cython
11
+ pycocotools
12
+ matplotlib
13
+ ```
14
+
15
+ You must modify the following flags:
16
+
17
+ `--data-path=/path/to/coco/dataset`
18
+
19
+ `--nproc_per_node=<number_of_gpus_available>`
20
+
21
+ Unless otherwise noted, all models have been trained on 8x V100 GPUs.
22
+
23
+ ### Faster R-CNN ResNet-50 FPN
24
+ ```
25
+ torchrun --nproc_per_node=8 train.py\
26
+ --dataset coco --model fasterrcnn_resnet50_fpn --epochs 26\
27
+ --lr-steps 16 22 --aspect-ratio-group-factor 3
28
+ ```
29
+
30
+ ### Faster R-CNN MobileNetV3-Large FPN
31
+ ```
32
+ torchrun --nproc_per_node=8 train.py\
33
+ --dataset coco --model fasterrcnn_mobilenet_v3_large_fpn --epochs 26\
34
+ --lr-steps 16 22 --aspect-ratio-group-factor 3
35
+ ```
36
+
37
+ ### Faster R-CNN MobileNetV3-Large 320 FPN
38
+ ```
39
+ torchrun --nproc_per_node=8 train.py\
40
+ --dataset coco --model fasterrcnn_mobilenet_v3_large_320_fpn --epochs 26\
41
+ --lr-steps 16 22 --aspect-ratio-group-factor 3
42
+ ```
43
+
44
+ ### RetinaNet
45
+ ```
46
+ torchrun --nproc_per_node=8 train.py\
47
+ --dataset coco --model retinanet_resnet50_fpn --epochs 26\
48
+ --lr-steps 16 22 --aspect-ratio-group-factor 3 --lr 0.01
49
+ ```
50
+
51
+ ### SSD300 VGG16
52
+ ```
53
+ torchrun --nproc_per_node=8 train.py\
54
+ --dataset coco --model ssd300_vgg16 --epochs 120\
55
+ --lr-steps 80 110 --aspect-ratio-group-factor 3 --lr 0.002 --batch-size 4\
56
+ --weight-decay 0.0005 --data-augmentation ssd
57
+ ```
58
+
59
+ ### SSDlite320 MobileNetV3-Large
60
+ ```
61
+ torchrun --nproc_per_node=8 train.py\
62
+ --dataset coco --model ssdlite320_mobilenet_v3_large --epochs 660\
63
+ --aspect-ratio-group-factor 3 --lr-scheduler cosineannealinglr --lr 0.15 --batch-size 24\
64
+ --weight-decay 0.00004 --data-augmentation ssdlite
65
+ ```
66
+
67
+
68
+ ### Mask R-CNN
69
+ ```
70
+ torchrun --nproc_per_node=8 train.py\
71
+ --dataset coco --model maskrcnn_resnet50_fpn --epochs 26\
72
+ --lr-steps 16 22 --aspect-ratio-group-factor 3
73
+ ```
74
+
75
+
76
+ ### Keypoint R-CNN
77
+ ```
78
+ torchrun --nproc_per_node=8 train.py\
79
+ --dataset coco_kp --model keypointrcnn_resnet50_fpn --epochs 46\
80
+ --lr-steps 36 43 --aspect-ratio-group-factor 3
81
+ ```
DenseMammogram/detection/coco_eval.py ADDED
@@ -0,0 +1,191 @@
1
+ import copy
2
+ import io
3
+ from contextlib import redirect_stdout
4
+
5
+ import numpy as np
6
+ import pycocotools.mask as mask_util
7
+ import torch
8
+ import detection.utils as utils
9
+ from pycocotools.coco import COCO
10
+ from pycocotools.cocoeval import COCOeval
11
+
12
+
13
+ class CocoEvaluator:
14
+ def __init__(self, coco_gt, iou_types):
15
+ assert isinstance(iou_types, (list, tuple))
16
+ coco_gt = copy.deepcopy(coco_gt)
17
+ self.coco_gt = coco_gt
18
+
19
+ self.iou_types = iou_types
20
+ self.coco_eval = {}
21
+ for iou_type in iou_types:
22
+ self.coco_eval[iou_type] = COCOeval(coco_gt, iouType=iou_type)
23
+
24
+ self.img_ids = []
25
+ self.eval_imgs = {k: [] for k in iou_types}
26
+
27
+ def update(self, predictions):
28
+ img_ids = list(np.unique(list(predictions.keys())))
29
+ self.img_ids.extend(img_ids)
30
+
31
+ for iou_type in self.iou_types:
32
+ results = self.prepare(predictions, iou_type)
33
+ with redirect_stdout(io.StringIO()):
34
+ coco_dt = COCO.loadRes(self.coco_gt, results) if results else COCO()
35
+ coco_eval = self.coco_eval[iou_type]
36
+
37
+ coco_eval.cocoDt = coco_dt
38
+ coco_eval.params.imgIds = list(img_ids)
39
+ img_ids, eval_imgs = evaluate(coco_eval)
40
+
41
+ self.eval_imgs[iou_type].append(eval_imgs)
42
+
43
+ def synchronize_between_processes(self):
44
+ for iou_type in self.iou_types:
45
+ self.eval_imgs[iou_type] = np.concatenate(self.eval_imgs[iou_type], 2)
46
+ create_common_coco_eval(self.coco_eval[iou_type], self.img_ids, self.eval_imgs[iou_type])
47
+
48
+ def accumulate(self):
49
+ for coco_eval in self.coco_eval.values():
50
+ coco_eval.accumulate()
51
+
52
+ def summarize(self):
53
+ for iou_type, coco_eval in self.coco_eval.items():
54
+ print(f"IoU metric: {iou_type}")
55
+ coco_eval.summarize()
56
+
57
+ def prepare(self, predictions, iou_type):
58
+ if iou_type == "bbox":
59
+ return self.prepare_for_coco_detection(predictions)
60
+ if iou_type == "segm":
61
+ return self.prepare_for_coco_segmentation(predictions)
62
+ if iou_type == "keypoints":
63
+ return self.prepare_for_coco_keypoint(predictions)
64
+ raise ValueError(f"Unknown iou type {iou_type}")
65
+
66
+ def prepare_for_coco_detection(self, predictions):
67
+ coco_results = []
68
+ for original_id, prediction in predictions.items():
69
+ if len(prediction) == 0:
70
+ continue
71
+
72
+ boxes = prediction["boxes"]
73
+ boxes = convert_to_xywh(boxes).tolist()
74
+ scores = prediction["scores"].tolist()
75
+ labels = prediction["labels"].tolist()
76
+
77
+ coco_results.extend(
78
+ [
79
+ {
80
+ "image_id": original_id,
81
+ "category_id": labels[k],
82
+ "bbox": box,
83
+ "score": scores[k],
84
+ }
85
+ for k, box in enumerate(boxes)
86
+ ]
87
+ )
88
+ return coco_results
89
+
90
+ def prepare_for_coco_segmentation(self, predictions):
91
+ coco_results = []
92
+ for original_id, prediction in predictions.items():
93
+ if len(prediction) == 0:
94
+ continue
95
+
96
+ scores = prediction["scores"]
97
+ labels = prediction["labels"]
98
+ masks = prediction["masks"]
99
+
100
+ masks = masks > 0.5
101
+
102
+ scores = prediction["scores"].tolist()
103
+ labels = prediction["labels"].tolist()
104
+
105
+ rles = [
106
+ mask_util.encode(np.array(mask[0, :, :, np.newaxis], dtype=np.uint8, order="F"))[0] for mask in masks
107
+ ]
108
+ for rle in rles:
109
+ rle["counts"] = rle["counts"].decode("utf-8")
110
+
111
+ coco_results.extend(
112
+ [
113
+ {
114
+ "image_id": original_id,
115
+ "category_id": labels[k],
116
+ "segmentation": rle,
117
+ "score": scores[k],
118
+ }
119
+ for k, rle in enumerate(rles)
120
+ ]
121
+ )
122
+ return coco_results
123
+
124
+ def prepare_for_coco_keypoint(self, predictions):
125
+ coco_results = []
126
+ for original_id, prediction in predictions.items():
127
+ if len(prediction) == 0:
128
+ continue
129
+
130
+ boxes = prediction["boxes"]
131
+ boxes = convert_to_xywh(boxes).tolist()
132
+ scores = prediction["scores"].tolist()
133
+ labels = prediction["labels"].tolist()
134
+ keypoints = prediction["keypoints"]
135
+ keypoints = keypoints.flatten(start_dim=1).tolist()
136
+
137
+ coco_results.extend(
138
+ [
139
+ {
140
+ "image_id": original_id,
141
+ "category_id": labels[k],
142
+ "keypoints": keypoint,
143
+ "score": scores[k],
144
+ }
145
+ for k, keypoint in enumerate(keypoints)
146
+ ]
147
+ )
148
+ return coco_results
149
+
150
+
151
+ def convert_to_xywh(boxes):
152
+ xmin, ymin, xmax, ymax = boxes.unbind(1)
153
+ return torch.stack((xmin, ymin, xmax - xmin, ymax - ymin), dim=1)
154
+
155
+
156
+ def merge(img_ids, eval_imgs):
157
+ all_img_ids = utils.all_gather(img_ids)
158
+ all_eval_imgs = utils.all_gather(eval_imgs)
159
+
160
+ merged_img_ids = []
161
+ for p in all_img_ids:
162
+ merged_img_ids.extend(p)
163
+
164
+ merged_eval_imgs = []
165
+ for p in all_eval_imgs:
166
+ merged_eval_imgs.append(p)
167
+
168
+ merged_img_ids = np.array(merged_img_ids)
169
+ merged_eval_imgs = np.concatenate(merged_eval_imgs, 2)
170
+
171
+ # keep only unique (and in sorted order) images
172
+ merged_img_ids, idx = np.unique(merged_img_ids, return_index=True)
173
+ merged_eval_imgs = merged_eval_imgs[..., idx]
174
+
175
+ return merged_img_ids, merged_eval_imgs
176
+
177
+
178
+ def create_common_coco_eval(coco_eval, img_ids, eval_imgs):
179
+ img_ids, eval_imgs = merge(img_ids, eval_imgs)
180
+ img_ids = list(img_ids)
181
+ eval_imgs = list(eval_imgs.flatten())
182
+
183
+ coco_eval.evalImgs = eval_imgs
184
+ coco_eval.params.imgIds = img_ids
185
+ coco_eval._paramsEval = copy.deepcopy(coco_eval.params)
186
+
187
+
188
+ def evaluate(imgs):
189
+ with redirect_stdout(io.StringIO()):
190
+ imgs.evaluate()
191
+ return imgs.params.imgIds, np.asarray(imgs.evalImgs).reshape(-1, len(imgs.params.areaRng), len(imgs.params.imgIds))
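A typical evaluation loop around `CocoEvaluator`, mirroring `detection/engine.py`; `model` and `data_loader` are assumed to exist already (e.g. built via `models.py` and `dataloaders.py`), so this is a sketch rather than a runnable script.

```python
import torch
from detection.coco_eval import CocoEvaluator
from detection.coco_utils import get_coco_api_from_dataset

coco_gt = get_coco_api_from_dataset(data_loader.dataset)   # data_loader assumed to exist
evaluator = CocoEvaluator(coco_gt, iou_types=["bbox"])

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.eval()                                               # model assumed to exist
with torch.inference_mode():
    for images, targets in data_loader:
        outputs = model([img.to(device) for img in images])
        outputs = [{k: v.to('cpu') for k, v in o.items()} for o in outputs]
        evaluator.update({t['image_id'].item(): o for t, o in zip(targets, outputs)})

evaluator.synchronize_between_processes()
evaluator.accumulate()
evaluator.summarize()
```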
DenseMammogram/detection/coco_utils.py ADDED
@@ -0,0 +1,249 @@
1
+ import copy
2
+ import os
3
+
4
+ import torch
5
+ import torch.utils.data
6
+ import torchvision
7
+ import detection.transforms as T
8
+ from pycocotools import mask as coco_mask
9
+ from pycocotools.coco import COCO
10
+
11
+
12
+ class FilterAndRemapCocoCategories:
13
+ def __init__(self, categories, remap=True):
14
+ self.categories = categories
15
+ self.remap = remap
16
+
17
+ def __call__(self, image, target):
18
+ anno = target["annotations"]
19
+ anno = [obj for obj in anno if obj["category_id"] in self.categories]
20
+ if not self.remap:
21
+ target["annotations"] = anno
22
+ return image, target
23
+ anno = copy.deepcopy(anno)
24
+ for obj in anno:
25
+ obj["category_id"] = self.categories.index(obj["category_id"])
26
+ target["annotations"] = anno
27
+ return image, target
28
+
29
+
30
+ def convert_coco_poly_to_mask(segmentations, height, width):
31
+ masks = []
32
+ for polygons in segmentations:
33
+ rles = coco_mask.frPyObjects(polygons, height, width)
34
+ mask = coco_mask.decode(rles)
35
+ if len(mask.shape) < 3:
36
+ mask = mask[..., None]
37
+ mask = torch.as_tensor(mask, dtype=torch.uint8)
38
+ mask = mask.any(dim=2)
39
+ masks.append(mask)
40
+ if masks:
41
+ masks = torch.stack(masks, dim=0)
42
+ else:
43
+ masks = torch.zeros((0, height, width), dtype=torch.uint8)
44
+ return masks
45
+
46
+
47
+ class ConvertCocoPolysToMask:
48
+ def __call__(self, image, target):
49
+ w, h = image.size
50
+
51
+ image_id = target["image_id"]
52
+ image_id = torch.tensor([image_id])
53
+
54
+ anno = target["annotations"]
55
+
56
+ anno = [obj for obj in anno if obj["iscrowd"] == 0]
57
+
58
+ boxes = [obj["bbox"] for obj in anno]
59
+ # guard against no boxes via resizing
60
+ boxes = torch.as_tensor(boxes, dtype=torch.float32).reshape(-1, 4)
61
+ boxes[:, 2:] += boxes[:, :2]
62
+ boxes[:, 0::2].clamp_(min=0, max=w)
63
+ boxes[:, 1::2].clamp_(min=0, max=h)
64
+
65
+ classes = [obj["category_id"] for obj in anno]
66
+ classes = torch.tensor(classes, dtype=torch.int64)
67
+
68
+ segmentations = [obj["segmentation"] for obj in anno]
69
+ masks = convert_coco_poly_to_mask(segmentations, h, w)
70
+
71
+ keypoints = None
72
+ if anno and "keypoints" in anno[0]:
73
+ keypoints = [obj["keypoints"] for obj in anno]
74
+ keypoints = torch.as_tensor(keypoints, dtype=torch.float32)
75
+ num_keypoints = keypoints.shape[0]
76
+ if num_keypoints:
77
+ keypoints = keypoints.view(num_keypoints, -1, 3)
78
+
79
+ keep = (boxes[:, 3] > boxes[:, 1]) & (boxes[:, 2] > boxes[:, 0])
80
+ boxes = boxes[keep]
81
+ classes = classes[keep]
82
+ masks = masks[keep]
83
+ if keypoints is not None:
84
+ keypoints = keypoints[keep]
85
+
86
+ target = {}
87
+ target["boxes"] = boxes
88
+ target["labels"] = classes
89
+ target["masks"] = masks
90
+ target["image_id"] = image_id
91
+ if keypoints is not None:
92
+ target["keypoints"] = keypoints
93
+
94
+ # for conversion to coco api
95
+ area = torch.tensor([obj["area"] for obj in anno])
96
+ iscrowd = torch.tensor([obj["iscrowd"] for obj in anno])
97
+ target["area"] = area
98
+ target["iscrowd"] = iscrowd
99
+
100
+ return image, target
101
+
102
+
103
+ def _coco_remove_images_without_annotations(dataset, cat_list=None):
104
+ def _has_only_empty_bbox(anno):
105
+ return all(any(o <= 1 for o in obj["bbox"][2:]) for obj in anno)
106
+
107
+ def _count_visible_keypoints(anno):
108
+ return sum(sum(1 for v in ann["keypoints"][2::3] if v > 0) for ann in anno)
109
+
110
+ min_keypoints_per_image = 10
111
+
112
+ def _has_valid_annotation(anno):
113
+ # if it's empty, there is no annotation
114
+ if len(anno) == 0:
115
+ return False
116
+ # if all boxes have close to zero area, there is no annotation
117
+ if _has_only_empty_bbox(anno):
118
+ return False
119
+ # keypoint tasks have a slightly different criterion for considering
+ # whether an annotation is valid
121
+ if "keypoints" not in anno[0]:
122
+ return True
123
+ # for keypoint detection tasks, only consider valid images those
124
+ # containing at least min_keypoints_per_image
125
+ if _count_visible_keypoints(anno) >= min_keypoints_per_image:
126
+ return True
127
+ return False
128
+
129
+ assert isinstance(dataset, torchvision.datasets.CocoDetection)
130
+ ids = []
131
+ for ds_idx, img_id in enumerate(dataset.ids):
132
+ ann_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
133
+ anno = dataset.coco.loadAnns(ann_ids)
134
+ if cat_list:
135
+ anno = [obj for obj in anno if obj["category_id"] in cat_list]
136
+ if _has_valid_annotation(anno):
137
+ ids.append(ds_idx)
138
+
139
+ dataset = torch.utils.data.Subset(dataset, ids)
140
+ return dataset
141
+
142
+
143
+ def convert_to_coco_api(ds):
144
+ coco_ds = COCO()
145
+ # annotation IDs need to start at 1, not 0, see torchvision issue #1530
146
+ ann_id = 1
147
+ dataset = {"images": [], "categories": [], "annotations": []}
148
+ categories = set()
149
+ for img_idx in range(len(ds)):
150
+ # find better way to get target
151
+ # targets = ds.get_annotations(img_idx)
152
+ img, targets = ds[img_idx]
153
+ image_id = targets["image_id"].item()
154
+ img_dict = {}
155
+ img_dict["id"] = image_id
156
+ img_dict["height"] = img.shape[-2]
157
+ img_dict["width"] = img.shape[-1]
158
+ dataset["images"].append(img_dict)
159
+ bboxes = targets["boxes"]
160
+ bboxes[:, 2:] -= bboxes[:, :2]
161
+ bboxes = bboxes.tolist()
162
+ labels = targets["labels"].tolist()
163
+ areas = targets["area"].tolist()
164
+ iscrowd = targets["iscrowd"].tolist()
165
+ if "masks" in targets:
166
+ masks = targets["masks"]
167
+ # make masks Fortran contiguous for coco_mask
168
+ masks = masks.permute(0, 2, 1).contiguous().permute(0, 2, 1)
169
+ if "keypoints" in targets:
170
+ keypoints = targets["keypoints"]
171
+ keypoints = keypoints.reshape(keypoints.shape[0], -1).tolist()
172
+ num_objs = len(bboxes)
173
+ for i in range(num_objs):
174
+ ann = {}
175
+ ann["image_id"] = image_id
176
+ ann["bbox"] = bboxes[i]
177
+ ann["category_id"] = labels[i]
178
+ categories.add(labels[i])
179
+ ann["area"] = areas[i]
180
+ ann["iscrowd"] = iscrowd[i]
181
+ ann["id"] = ann_id
182
+ if "masks" in targets:
183
+ ann["segmentation"] = coco_mask.encode(masks[i].numpy())
184
+ if "keypoints" in targets:
185
+ ann["keypoints"] = keypoints[i]
186
+ ann["num_keypoints"] = sum(k != 0 for k in keypoints[i][2::3])
187
+ dataset["annotations"].append(ann)
188
+ ann_id += 1
189
+ dataset["categories"] = [{"id": i} for i in sorted(categories)]
190
+ coco_ds.dataset = dataset
191
+ coco_ds.createIndex()
192
+ return coco_ds
193
+
194
+
195
+ def get_coco_api_from_dataset(dataset):
196
+ for _ in range(10):
197
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
198
+ break
199
+ if isinstance(dataset, torch.utils.data.Subset):
200
+ dataset = dataset.dataset
201
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
202
+ return dataset.coco
203
+ return convert_to_coco_api(dataset)
204
+
205
+
206
+ class CocoDetection(torchvision.datasets.CocoDetection):
207
+ def __init__(self, img_folder, ann_file, transforms):
208
+ super().__init__(img_folder, ann_file)
209
+ self._transforms = transforms
210
+
211
+ def __getitem__(self, idx):
212
+ img, target = super().__getitem__(idx)
213
+ image_id = self.ids[idx]
214
+ target = dict(image_id=image_id, annotations=target)
215
+ if self._transforms is not None:
216
+ img, target = self._transforms(img, target)
217
+ return img, target
218
+
219
+
220
+ def get_coco(root, image_set, transforms, mode="instances"):
221
+ anno_file_template = "{}_{}2017.json"
222
+ PATHS = {
223
+ "train": ("train2017", os.path.join("annotations", anno_file_template.format(mode, "train"))),
224
+ "val": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val"))),
225
+ # "train": ("val2017", os.path.join("annotations", anno_file_template.format(mode, "val")))
226
+ }
227
+
228
+ t = [ConvertCocoPolysToMask()]
229
+
230
+ if transforms is not None:
231
+ t.append(transforms)
232
+ transforms = T.Compose(t)
233
+
234
+ img_folder, ann_file = PATHS[image_set]
235
+ img_folder = os.path.join(root, img_folder)
236
+ ann_file = os.path.join(root, ann_file)
237
+
238
+ dataset = CocoDetection(img_folder, ann_file, transforms=transforms)
239
+
240
+ if image_set == "train":
241
+ dataset = _coco_remove_images_without_annotations(dataset)
242
+
243
+ # dataset = torch.utils.data.Subset(dataset, [i for i in range(500)])
244
+
245
+ return dataset
246
+
247
+
248
+ def get_coco_kp(root, image_set, transforms):
249
+ return get_coco(root, image_set, transforms, mode="person_keypoints")
DenseMammogram/detection/engine.py ADDED
@@ -0,0 +1,276 @@
1
+ import math
2
+ import sys
3
+ import time
4
+
5
+ import torch
6
+ import torchvision.models.detection.mask_rcnn
7
+ import detection.utils as utils
8
+ from detection.coco_eval import CocoEvaluator
9
+ from detection.coco_utils import get_coco_api_from_dataset
10
+ from tqdm import tqdm
11
+ import numpy as np
12
+
13
+
14
+ sys.path.append("..")
15
+ from utils import AverageMeter
16
+ from advanced_logger import LogPriority
17
+
18
+ def train_one_epoch(model, optimizer, data_loader, device, epoch, print_freq, scaler=None):
19
+ model.train()
20
+ metric_logger = utils.MetricLogger(delimiter=" ")
21
+ metric_logger.add_meter("lr", utils.SmoothedValue(window_size=1, fmt="{value:.6f}"))
22
+ header = f"Epoch: [{epoch}]"
23
+
24
+ lr_scheduler = None
25
+ if epoch == 0:
26
+ warmup_factor = 1.0 / 1000
27
+ warmup_iters = min(1000, len(data_loader) - 1)
28
+
29
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
30
+ optimizer, start_factor=warmup_factor, total_iters=warmup_iters
31
+ )
32
+ #for batch_idx,(images, targets) in enumerate(tqdm(data_loader)):
33
+ for images, targets in metric_logger.log_every(data_loader, print_freq, header):
34
+ #print(images.shape)
35
+ images = list(image.to(device) if len(image)>2 else [image[0].to(device),image[1].to(device)] for image in images)
36
+ #print(len(images))
37
+ #print(images[0].shape)
38
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
39
+ with torch.cuda.amp.autocast(enabled=scaler is not None):
40
+ loss_dict = model(images, targets)
41
+ losses = sum(loss for loss in loss_dict.values())
42
+
43
+ # reduce losses over all GPUs for logging purposes
44
+ loss_dict_reduced = utils.reduce_dict(loss_dict)
45
+ losses_reduced = sum(loss for loss in loss_dict_reduced.values())
46
+
47
+ loss_value = losses_reduced.item()
48
+
49
+ if not math.isfinite(loss_value):
50
+ print(f"Loss is {loss_value}, stopping training")
51
+ print(loss_dict_reduced)
52
+ sys.exit(1)
53
+
54
+ optimizer.zero_grad()
55
+ if scaler is not None:
56
+ scaler.scale(losses).backward()
57
+ scaler.step(optimizer)
58
+ scaler.update()
59
+ else:
60
+ losses.backward()
61
+ optimizer.step()
62
+
63
+ if lr_scheduler is not None:
64
+ lr_scheduler.step()
65
+
66
+ #if(batch_idx%20==0):
67
+ # print('epoch {} batch {} : {}'.format(epoch,batch_idx,losses_reduced))
68
+
69
+ metric_logger.update(loss=losses_reduced, **loss_dict_reduced)
70
+ metric_logger.update(lr=optimizer.param_groups[0]["lr"])
71
+
72
+ return metric_logger
73
+
74
+
75
+ def train_one_epoch_simplified(model, optimizer, data_loader, device, epoch, experimenter,optimizer_backbone=None):
76
+
77
+ model.train()
78
+ lr_scheduler = None
79
+ lr_scheduler_backbone = None
80
+ if epoch == 0:
81
+ warmup_factor = 1.0 / 1000
82
+ warmup_iters = min(1000, len(data_loader) - 1)
83
+
84
+ lr_scheduler = torch.optim.lr_scheduler.LinearLR(
85
+ optimizer, start_factor=warmup_factor, total_iters=warmup_iters
86
+ )
87
+ if(optimizer_backbone is not None):
88
+ lr_scheduler_backbone = torch.optim.lr_scheduler.LinearLR(optimizer_backbone, start_factor=warmup_factor, total_iters=warmup_iters)
89
+
90
+
91
+ loss_meter = AverageMeter()
92
+
93
+ for step, (images, targets) in enumerate(tqdm(data_loader)):
94
+
95
+ optimizer.zero_grad()
96
+ if(optimizer_backbone is not None):
97
+ optimizer_backbone.zero_grad()
98
+
99
+ images = list(image.to(device) if len(image)>2 else [image[0].to(device),image[1].to(device)] for image in images)
100
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
101
+ loss_dict = model(images, targets)
102
+ losses = sum(loss for loss in loss_dict.values())
103
+
104
+
105
+ if not math.isfinite(losses.item()):
106
+ print(f"Loss is {losses.item()}, stopping training")
107
+ print(loss_dict)
108
+ experimenter.log(f"Loss is {losses.item()}, stopping training")
109
+ sys.exit(1)
110
+
111
+ losses.backward()
112
+ loss_meter.update(losses.item())
113
+ optimizer.step()
114
+ if optimizer_backbone is not None:
115
+ optimizer_backbone.step()
116
+ if lr_scheduler is not None:
117
+ lr_scheduler.step()
118
+ if lr_scheduler_backbone is not None:
119
+ lr_scheduler_backbone.step()
120
+
121
+ if (step+1)%10 == 0:
122
+ experimenter.log('Loss after {} steps: {}'.format(step+1, loss_meter.avg))
123
+ if epoch == 0 and (step+1)%50 == 0:
124
+ experimenter.log('LR after {} steps: {}'.format(step+1, optimizer.param_groups[0]['lr']))
125
+
126
+ def _get_iou_types(model):
127
+ model_without_ddp = model
128
+ if isinstance(model, torch.nn.parallel.DistributedDataParallel):
129
+ model_without_ddp = model.module
130
+ iou_types = ["bbox"]
131
+ if isinstance(model_without_ddp, torchvision.models.detection.MaskRCNN):
132
+ iou_types.append("segm")
133
+ if isinstance(model_without_ddp, torchvision.models.detection.KeypointRCNN):
134
+ iou_types.append("keypoints")
135
+ return iou_types
136
+
137
+
138
+ @torch.inference_mode()
139
+ def evaluate(model, data_loader, device):
140
+ n_threads = torch.get_num_threads()
141
+ # FIXME remove this and make paste_masks_in_image run on the GPU
142
+ torch.set_num_threads(1)
143
+ cpu_device = torch.device("cpu")
144
+ model.eval()
145
+ metric_logger = utils.MetricLogger(delimiter=" ")
146
+ header = "Test:"
147
+
148
+ coco = get_coco_api_from_dataset(data_loader.dataset)
149
+ iou_types = _get_iou_types(model)
150
+ coco_evaluator = CocoEvaluator(coco, iou_types)
151
+
152
+ for images, targets in metric_logger.log_every(data_loader, 100, header):
153
+ images = list(img.to(device) for img in images)
154
+
155
+ if torch.cuda.is_available():
156
+ torch.cuda.synchronize()
157
+ model_time = time.time()
158
+ outputs = model(images)
159
+
160
+ outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
161
+ model_time = time.time() - model_time
162
+
163
+ res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
164
+ evaluator_time = time.time()
165
+ coco_evaluator.update(res)
166
+ evaluator_time = time.time() - evaluator_time
167
+ metric_logger.update(model_time=model_time, evaluator_time=evaluator_time)
168
+
169
+ # gather the stats from all processes
170
+ metric_logger.synchronize_between_processes()
171
+ print("Averaged stats:", metric_logger)
172
+ coco_evaluator.synchronize_between_processes()
173
+
174
+ # accumulate predictions from all images
175
+ coco_evaluator.accumulate()
176
+ coco_evaluator.summarize()
177
+ torch.set_num_threads(n_threads)
178
+ return coco_evaluator
179
+
180
+
181
+ def coco_summ(coco_eval, experimenter):
182
+ self = coco_eval
183
+ def _summarize( ap=1, iouThr=None, areaRng='all', maxDets=100 ):
184
+ p = self.params
185
+ iStr = ' {:<18} {} @[ IoU={:<9} | area={:>6s} | maxDets={:>3d} ] = {:0.3f}'
186
+ titleStr = 'Average Precision' if ap == 1 else 'Average Recall'
187
+ typeStr = '(AP)' if ap==1 else '(AR)'
188
+ iouStr = '{:0.2f}:{:0.2f}'.format(p.iouThrs[0], p.iouThrs[-1]) \
189
+ if iouThr is None else '{:0.2f}'.format(iouThr)
190
+
191
+ aind = [i for i, aRng in enumerate(p.areaRngLbl) if aRng == areaRng]
192
+ mind = [i for i, mDet in enumerate(p.maxDets) if mDet == maxDets]
193
+ if ap == 1:
194
+ # dimension of precision: [TxRxKxAxM]
195
+ s = self.eval['precision']
196
+ # IoU
197
+ if iouThr is not None:
198
+ t = np.where(iouThr == p.iouThrs)[0]
199
+ s = s[t]
200
+ s = s[:,:,:,aind,mind]
201
+ else:
202
+ # dimension of recall: [TxKxAxM]
203
+ s = self.eval['recall']
204
+ if iouThr is not None:
205
+ t = np.where(iouThr == p.iouThrs)[0]
206
+ s = s[t]
207
+ s = s[:,:,aind,mind]
208
+ if len(s[s>-1])==0:
209
+ mean_s = -1
210
+ else:
211
+ mean_s = np.mean(s[s>-1])
212
+ experimenter.log(iStr.format(titleStr, typeStr, iouStr, areaRng, maxDets, mean_s), priority = LogPriority.MEDIUM)
213
+ return mean_s
214
+ def _summarizeDets():
215
+ stats = np.zeros((12,))
216
+ stats[0] = _summarize(1)
217
+ stats[1] = _summarize(1, iouThr=.5, maxDets=self.params.maxDets[2])
218
+ stats[2] = _summarize(1, iouThr=.75, maxDets=self.params.maxDets[2])
219
+ stats[3] = _summarize(1, areaRng='small', maxDets=self.params.maxDets[2])
220
+ stats[4] = _summarize(1, areaRng='medium', maxDets=self.params.maxDets[2])
221
+ stats[5] = _summarize(1, areaRng='large', maxDets=self.params.maxDets[2])
222
+ stats[6] = _summarize(0, maxDets=self.params.maxDets[0])
223
+ stats[7] = _summarize(0, maxDets=self.params.maxDets[1])
224
+ stats[8] = _summarize(0, maxDets=self.params.maxDets[2])
225
+ stats[9] = _summarize(0, areaRng='small', maxDets=self.params.maxDets[2])
226
+ stats[10] = _summarize(0, areaRng='medium', maxDets=self.params.maxDets[2])
227
+ stats[11] = _summarize(0, areaRng='large', maxDets=self.params.maxDets[2])
228
+ return stats
229
+ _summarizeDets()
230
+
231
+ @torch.inference_mode()
232
+ def evaluate_simplified(model, data_loader, device, experimenter):
233
+ cpu_device = torch.device("cpu")
234
+ model.eval()
235
+ experimenter.log('Evaluating Validation Parameters')
236
+
237
+ coco = get_coco_api_from_dataset(data_loader.dataset)
238
+ iou_types = _get_iou_types(model)
239
+ coco_evaluator = CocoEvaluator(coco, iou_types)
240
+
241
+ for images, targets in data_loader:
242
+ images = list(img.to(device) for img in images)
243
+
244
+ if torch.cuda.is_available():
245
+ torch.cuda.synchronize()
246
+ outputs = model(images)
247
+ outputs = [{k: v.to(cpu_device) for k, v in t.items()} for t in outputs]
248
+ res = {target["image_id"].item(): output for target, output in zip(targets, outputs)}
249
+ coco_evaluator.update(res)
250
+
251
+ # gather the stats from all processes
252
+ coco_evaluator.synchronize_between_processes()
253
+
254
+ # accumulate predictions from all images
255
+ coco_evaluator.accumulate()
256
+
257
+ # Debug and see what all info it has
258
+ # coco_evaluator.summarize()
259
+ for iou_type, coco_eval in coco_evaluator.coco_eval.items():
260
+ print(f"IoU metric: {iou_type}")
261
+ coco_summ(coco_eval, experimenter)
262
+
263
+ return coco_evaluator
264
+
265
+ def evaluate_loss(model, device, val_loader, experimenter=None):
266
+ model.train()
267
+ #experimenter.log('Evaluating Validation Loss')
268
+ with torch.no_grad():
269
+ loss_meter = AverageMeter()
270
+ for images, targets in tqdm(val_loader):
271
+ # two-element entries (image pairs, e.g. for the bilateral model) are moved tensor by tensor;
+ # anything longer is treated as a single (C, H, W) image tensor
+ images = list(image.to(device) if len(image) > 2 else [image[0].to(device), image[1].to(device)] for image in images)
272
+ targets = [{k: v.to(device) for k, v in t.items()} for t in targets]
273
+ loss_dict = model(images, targets)
274
+ losses = sum(loss for loss in loss_dict.values())
275
+ loss_meter.update(losses.item())
276
+ return loss_meter.avg
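
For reference, the `_summarize` helper above reduces pycocotools' accumulated `precision` array (shape `[T, R, K, A, M]`: IoU thresholds, recall levels, classes, area ranges, max-detection settings) to a single mean. Below is a minimal sketch of that slicing using random stand-in data and the usual COCOeval dimensions; the array and thresholds are placeholders, not real evaluation output.

```python
import numpy as np

# stand-ins for coco_eval.eval['precision'] and coco_eval.params.iouThrs
T, R, K, A, M = 10, 101, 1, 4, 3
precision = np.random.rand(T, R, K, A, M)
iouThrs = np.linspace(0.5, 0.95, T)

# what _summarize(1, iouThr=.5, maxDets=100) effectively does:
t = np.where(np.isclose(iouThrs, 0.5))[0]   # select the IoU=0.50 slice
s = precision[t][:, :, :, 0, -1]            # area range 'all' (index 0), maxDets=100 (last index)
ap50 = s[s > -1].mean() if (s > -1).any() else -1
print(f"AP@0.50 (random stand-in data): {ap50:.3f}")
```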
DenseMammogram/detection/group_by_aspect_ratio.py ADDED
@@ -0,0 +1,196 @@
1
+ import bisect
2
+ import copy
3
+ import math
4
+ from collections import defaultdict
5
+ from itertools import repeat, chain
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.utils.data
10
+ import torchvision
11
+ from PIL import Image
12
+ from torch.utils.data.sampler import BatchSampler, Sampler
13
+ from torch.utils.model_zoo import tqdm
14
+
15
+
16
+ def _repeat_to_at_least(iterable, n):
17
+ repeat_times = math.ceil(n / len(iterable))
18
+ repeated = chain.from_iterable(repeat(iterable, repeat_times))
19
+ return list(repeated)
20
+
21
+
22
+ class GroupedBatchSampler(BatchSampler):
23
+ """
24
+ Wraps another sampler to yield a mini-batch of indices.
25
+ It enforces that the batch only contains elements from the same group.
26
+ It also tries to provide mini-batches that follow an ordering which is
27
+ as close as possible to the ordering from the original sampler.
28
+ Args:
29
+ sampler (Sampler): Base sampler.
30
+ group_ids (list[int]): If the sampler produces indices in range [0, N),
31
+ `group_ids` must be a list of `N` ints which contains the group id of each sample.
32
+ The group ids must be a continuous set of integers starting from
33
+ 0, i.e. they must be in the range [0, num_groups).
34
+ batch_size (int): Size of mini-batch.
35
+ """
36
+
37
+ def __init__(self, sampler, group_ids, batch_size):
38
+ if not isinstance(sampler, Sampler):
39
+ raise ValueError(f"sampler should be an instance of torch.utils.data.Sampler, but got sampler={sampler}")
40
+ self.sampler = sampler
41
+ self.group_ids = group_ids
42
+ self.batch_size = batch_size
43
+
44
+ def __iter__(self):
45
+ buffer_per_group = defaultdict(list)
46
+ samples_per_group = defaultdict(list)
47
+
48
+ num_batches = 0
49
+ for idx in self.sampler:
50
+ group_id = self.group_ids[idx]
51
+ buffer_per_group[group_id].append(idx)
52
+ samples_per_group[group_id].append(idx)
53
+ if len(buffer_per_group[group_id]) == self.batch_size:
54
+ yield buffer_per_group[group_id]
55
+ num_batches += 1
56
+ del buffer_per_group[group_id]
57
+ assert len(buffer_per_group[group_id]) < self.batch_size
58
+
59
+ # now we have run out of elements that satisfy
60
+ # the group criteria, let's return the remaining
61
+ # elements so that the size of the sampler is
62
+ # deterministic
63
+ expected_num_batches = len(self)
64
+ num_remaining = expected_num_batches - num_batches
65
+ if num_remaining > 0:
66
+ # for the remaining batches, take first the buffers with largest number
67
+ # of elements
68
+ for group_id, _ in sorted(buffer_per_group.items(), key=lambda x: len(x[1]), reverse=True):
69
+ remaining = self.batch_size - len(buffer_per_group[group_id])
70
+ samples_from_group_id = _repeat_to_at_least(samples_per_group[group_id], remaining)
71
+ buffer_per_group[group_id].extend(samples_from_group_id[:remaining])
72
+ assert len(buffer_per_group[group_id]) == self.batch_size
73
+ yield buffer_per_group[group_id]
74
+ num_remaining -= 1
75
+ if num_remaining == 0:
76
+ break
77
+ assert num_remaining == 0
78
+
79
+ def __len__(self):
80
+ return len(self.sampler) // self.batch_size
81
+
82
+
83
+ def _compute_aspect_ratios_slow(dataset, indices=None):
84
+ print(
85
+ "Your dataset doesn't support the fast path for "
86
+ "computing the aspect ratios, so will iterate over "
87
+ "the full dataset and load every image instead. "
88
+ "This might take some time..."
89
+ )
90
+ if indices is None:
91
+ indices = range(len(dataset))
92
+
93
+ class SubsetSampler(Sampler):
94
+ def __init__(self, indices):
95
+ self.indices = indices
96
+
97
+ def __iter__(self):
98
+ return iter(self.indices)
99
+
100
+ def __len__(self):
101
+ return len(self.indices)
102
+
103
+ sampler = SubsetSampler(indices)
104
+ data_loader = torch.utils.data.DataLoader(
105
+ dataset,
106
+ batch_size=1,
107
+ sampler=sampler,
108
+ num_workers=14, # you might want to increase it for faster processing
109
+ collate_fn=lambda x: x[0],
110
+ )
111
+ aspect_ratios = []
112
+ with tqdm(total=len(dataset)) as pbar:
113
+ for _i, (img, _) in enumerate(data_loader):
114
+ pbar.update(1)
115
+ height, width = img.shape[-2:]
116
+ aspect_ratio = float(width) / float(height)
117
+ aspect_ratios.append(aspect_ratio)
118
+ return aspect_ratios
119
+
120
+
121
+ def _compute_aspect_ratios_custom_dataset(dataset, indices=None):
122
+ if indices is None:
123
+ indices = range(len(dataset))
124
+ aspect_ratios = []
125
+ for i in indices:
126
+ height, width = dataset.get_height_and_width(i)
127
+ aspect_ratio = float(width) / float(height)
128
+ aspect_ratios.append(aspect_ratio)
129
+ return aspect_ratios
130
+
131
+
132
+ def _compute_aspect_ratios_coco_dataset(dataset, indices=None):
133
+ if indices is None:
134
+ indices = range(len(dataset))
135
+ aspect_ratios = []
136
+ for i in indices:
137
+ img_info = dataset.coco.imgs[dataset.ids[i]]
138
+ aspect_ratio = float(img_info["width"]) / float(img_info["height"])
139
+ aspect_ratios.append(aspect_ratio)
140
+ return aspect_ratios
141
+
142
+
143
+ def _compute_aspect_ratios_voc_dataset(dataset, indices=None):
144
+ if indices is None:
145
+ indices = range(len(dataset))
146
+ aspect_ratios = []
147
+ for i in indices:
148
+ # this doesn't load the data into memory, because PIL loads it lazily
149
+ width, height = Image.open(dataset.images[i]).size
150
+ aspect_ratio = float(width) / float(height)
151
+ aspect_ratios.append(aspect_ratio)
152
+ return aspect_ratios
153
+
154
+
155
+ def _compute_aspect_ratios_subset_dataset(dataset, indices=None):
156
+ if indices is None:
157
+ indices = range(len(dataset))
158
+
159
+ ds_indices = [dataset.indices[i] for i in indices]
160
+ return compute_aspect_ratios(dataset.dataset, ds_indices)
161
+
162
+
163
+ def compute_aspect_ratios(dataset, indices=None):
164
+ if hasattr(dataset, "get_height_and_width"):
165
+ return _compute_aspect_ratios_custom_dataset(dataset, indices)
166
+
167
+ if isinstance(dataset, torchvision.datasets.CocoDetection):
168
+ return _compute_aspect_ratios_coco_dataset(dataset, indices)
169
+
170
+ if isinstance(dataset, torchvision.datasets.VOCDetection):
171
+ return _compute_aspect_ratios_voc_dataset(dataset, indices)
172
+
173
+ if isinstance(dataset, torch.utils.data.Subset):
174
+ return _compute_aspect_ratios_subset_dataset(dataset, indices)
175
+
176
+ # slow path
177
+ return _compute_aspect_ratios_slow(dataset, indices)
178
+
179
+
180
+ def _quantize(x, bins):
181
+ bins = copy.deepcopy(bins)
182
+ bins = sorted(bins)
183
+ quantized = list(map(lambda y: bisect.bisect_right(bins, y), x))
184
+ return quantized
185
+
186
+
187
+ def create_aspect_ratio_groups(dataset, k=0):
188
+ aspect_ratios = compute_aspect_ratios(dataset)
189
+ bins = (2 ** np.linspace(-1, 1, 2 * k + 1)).tolist() if k > 0 else [1.0]
190
+ groups = _quantize(aspect_ratios, bins)
191
+ # count number of elements per group
192
+ counts = np.unique(groups, return_counts=True)[1]
193
+ fbins = [0] + bins + [np.inf]
194
+ print(f"Using {fbins} as bins for aspect ratio quantization")
195
+ print(f"Count of instances per bin: {counts}")
196
+ return groups
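
A hedged usage sketch of the grouping utilities above: the toy dataset (hypothetical, defined only for this example) exposes `get_height_and_width` so `compute_aspect_ratios` takes the fast path, and `GroupedBatchSampler` then keeps wide and tall images in separate batches.

```python
import torch
from torch.utils.data import Dataset
from detection.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
from detection.utils import collate_fn

class ToyDetectionDataset(Dataset):
    """Stand-in dataset: even indices are wide images, odd indices are tall."""
    def __init__(self, n=8):
        self.sizes = [(480, 640) if i % 2 == 0 else (640, 480) for i in range(n)]
    def __len__(self):
        return len(self.sizes)
    def get_height_and_width(self, idx):   # fast path used by compute_aspect_ratios
        return self.sizes[idx]
    def __getitem__(self, idx):
        h, w = self.sizes[idx]
        img = torch.zeros(3, h, w)
        target = {"boxes": torch.tensor([[0.0, 0.0, 10.0, 10.0]]),
                  "labels": torch.tensor([1]), "image_id": torch.tensor([idx])}
        return img, target

ds = ToyDetectionDataset()
group_ids = create_aspect_ratio_groups(ds, k=0)          # single cut at aspect ratio 1.0
sampler = torch.utils.data.RandomSampler(ds)
batch_sampler = GroupedBatchSampler(sampler, group_ids, batch_size=2)
loader = torch.utils.data.DataLoader(ds, batch_sampler=batch_sampler, collate_fn=collate_fn)
```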
DenseMammogram/detection/presets.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import detection.transforms as T
3
+
4
+
5
+ class DetectionPresetTrain:
6
+ def __init__(self, data_augmentation, hflip_prob=0.5, mean=(123.0, 117.0, 104.0)):
7
+ if data_augmentation == "hflip":
8
+ self.transforms = T.Compose(
9
+ [
10
+ T.RandomHorizontalFlip(p=hflip_prob),
11
+ T.PILToTensor(),
12
+ T.ConvertImageDtype(torch.float),
13
+ ]
14
+ )
15
+ elif data_augmentation == "ssd":
16
+ self.transforms = T.Compose(
17
+ [
18
+ T.RandomPhotometricDistort(),
19
+ T.RandomZoomOut(fill=list(mean)),
20
+ T.RandomIoUCrop(),
21
+ T.RandomHorizontalFlip(p=hflip_prob),
22
+ T.PILToTensor(),
23
+ T.ConvertImageDtype(torch.float),
24
+ ]
25
+ )
26
+ elif data_augmentation == "ssdlite":
27
+ self.transforms = T.Compose(
28
+ [
29
+ T.RandomIoUCrop(),
30
+ T.RandomHorizontalFlip(p=hflip_prob),
31
+ T.PILToTensor(),
32
+ T.ConvertImageDtype(torch.float),
33
+ ]
34
+ )
35
+ else:
36
+ raise ValueError(f'Unknown data augmentation policy "{data_augmentation}"')
37
+
38
+ def __call__(self, img, target):
39
+ return self.transforms(img, target)
40
+
41
+
42
+ class DetectionPresetEval:
43
+ def __init__(self):
44
+ self.transforms = T.ToTensor()
45
+
46
+ def __call__(self, img, target):
47
+ return self.transforms(img, target)
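
A small sketch of how these presets are meant to be called: the "hflip" policy flips the image and its boxes together and returns a float tensor. The image and target below are placeholders.

```python
import torch
from PIL import Image
from detection.presets import DetectionPresetTrain, DetectionPresetEval

train_tf = DetectionPresetTrain("hflip", hflip_prob=0.5)
img = Image.new("RGB", (640, 480))                      # placeholder image
target = {"boxes": torch.tensor([[10.0, 20.0, 100.0, 200.0]]),
          "labels": torch.tensor([1])}

img_t, target = train_tf(img, target)                   # float tensor in [0, 1]; boxes flipped with the image
eval_tf = DetectionPresetEval()
```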
DenseMammogram/detection/train.py ADDED
@@ -0,0 +1,269 @@
1
+ r"""PyTorch Detection Training.
2
+
3
+ To run in a multi-gpu environment, use the distributed launcher::
4
+
5
+ python -m torch.distributed.launch --nproc_per_node=$NGPU --use_env \
6
+ train.py ... --world-size $NGPU
7
+
8
+ The default hyperparameters are tuned for training on 8 gpus and 2 images per gpu.
9
+ --lr 0.02 --batch-size 2 --world-size 8
10
+ If you use a different number of GPUs, the learning rate should be scaled to 0.02/8*$NGPU.
11
+
12
+ On top of that, for training Faster/Mask R-CNN, the default hyperparameters are
13
+ --epochs 26 --lr-steps 16 22 --aspect-ratio-group-factor 3
14
+
15
+ Also, if you train Keypoint R-CNN, the default hyperparameters are
16
+ --epochs 46 --lr-steps 36 43 --aspect-ratio-group-factor 3
17
+ Because the number of images is smaller in the person keypoint subset of COCO,
18
+ the number of epochs should be adapted so that we have the same number of iterations.
19
+ """
20
+ import datetime
21
+ import os
22
+ import time
23
+
24
+ import detection.presets as presets
25
+ import torch
26
+ import torch.utils.data
27
+ import torchvision
28
+ import torchvision.models.detection
29
+ import torchvision.models.detection.mask_rcnn
30
+ import detection.utils as utils
31
+ from detection.coco_utils import get_coco, get_coco_kp
32
+ from detection.engine import train_one_epoch, evaluate
33
+ from detection.group_by_aspect_ratio import GroupedBatchSampler, create_aspect_ratio_groups
34
+
35
+
36
+ try:
37
+ from torchvision.prototype import models as PM
38
+ except ImportError:
39
+ PM = None
40
+
41
+
42
+ def get_dataset(name, image_set, transform, data_path):
43
+ paths = {"coco": (data_path, get_coco, 91), "coco_kp": (data_path, get_coco_kp, 2)}
44
+ p, ds_fn, num_classes = paths[name]
45
+
46
+ ds = ds_fn(p, image_set=image_set, transforms=transform)
47
+ return ds, num_classes
48
+
49
+
50
+ def get_transform(train, args):
51
+ if train:
52
+ return presets.DetectionPresetTrain(args.data_augmentation)
53
+ elif not args.weights:
54
+ return presets.DetectionPresetEval()
55
+ else:
56
+ weights = PM.get_weight(args.weights)
57
+ return weights.transforms()
58
+
59
+
60
+ def get_args_parser(add_help=True):
61
+ import argparse
62
+
63
+ parser = argparse.ArgumentParser(description="PyTorch Detection Training", add_help=add_help)
64
+
65
+ parser.add_argument("--data-path", default="/datasets01/COCO/022719/", type=str, help="dataset path")
66
+ parser.add_argument("--dataset", default="coco", type=str, help="dataset name")
67
+ parser.add_argument("--model", default="maskrcnn_resnet50_fpn", type=str, help="model name")
68
+ parser.add_argument("--device", default="cuda", type=str, help="device to use: cuda or cpu (default: cuda)")
69
+ parser.add_argument(
70
+ "-b", "--batch-size", default=2, type=int, help="images per gpu, the total batch size is $NGPU x batch_size"
71
+ )
72
+ parser.add_argument("--epochs", default=26, type=int, metavar="N", help="number of total epochs to run")
73
+ parser.add_argument(
74
+ "-j", "--workers", default=4, type=int, metavar="N", help="number of data loading workers (default: 4)"
75
+ )
76
+ parser.add_argument(
77
+ "--lr",
78
+ default=0.02,
79
+ type=float,
80
+ help="initial learning rate, 0.02 is the default value for training on 8 gpus and 2 images_per_gpu",
81
+ )
82
+ parser.add_argument("--momentum", default=0.9, type=float, metavar="M", help="momentum")
83
+ parser.add_argument(
84
+ "--wd",
85
+ "--weight-decay",
86
+ default=1e-4,
87
+ type=float,
88
+ metavar="W",
89
+ help="weight decay (default: 1e-4)",
90
+ dest="weight_decay",
91
+ )
92
+ parser.add_argument(
93
+ "--lr-scheduler", default="multisteplr", type=str, help="name of lr scheduler (default: multisteplr)"
94
+ )
95
+ parser.add_argument(
96
+ "--lr-step-size", default=8, type=int, help="decrease lr every step-size epochs (multisteplr scheduler only)"
97
+ )
98
+ parser.add_argument(
99
+ "--lr-steps",
100
+ default=[16, 22],
101
+ nargs="+",
102
+ type=int,
103
+ help="decrease lr every step-size epochs (multisteplr scheduler only)",
104
+ )
105
+ parser.add_argument(
106
+ "--lr-gamma", default=0.1, type=float, help="decrease lr by a factor of lr-gamma (multisteplr scheduler only)"
107
+ )
108
+ parser.add_argument("--print-freq", default=20, type=int, help="print frequency")
109
+ parser.add_argument("--output-dir", default=".", type=str, help="path to save outputs")
110
+ parser.add_argument("--resume", default="", type=str, help="path of checkpoint")
111
+ parser.add_argument("--start_epoch", default=0, type=int, help="start epoch")
112
+ parser.add_argument("--aspect-ratio-group-factor", default=3, type=int)
113
+ parser.add_argument("--rpn-score-thresh", default=None, type=float, help="rpn score threshold for faster-rcnn")
114
+ parser.add_argument(
115
+ "--trainable-backbone-layers", default=None, type=int, help="number of trainable layers of backbone"
116
+ )
117
+ parser.add_argument(
118
+ "--data-augmentation", default="hflip", type=str, help="data augmentation policy (default: hflip)"
119
+ )
120
+ parser.add_argument(
121
+ "--sync-bn",
122
+ dest="sync_bn",
123
+ help="Use sync batch norm",
124
+ action="store_true",
125
+ )
126
+ parser.add_argument(
127
+ "--test-only",
128
+ dest="test_only",
129
+ help="Only test the model",
130
+ action="store_true",
131
+ )
132
+ parser.add_argument(
133
+ "--pretrained",
134
+ dest="pretrained",
135
+ help="Use pre-trained models from the modelzoo",
136
+ action="store_true",
137
+ )
138
+
139
+ # distributed training parameters
140
+ parser.add_argument("--world-size", default=1, type=int, help="number of distributed processes")
141
+ parser.add_argument("--dist-url", default="env://", type=str, help="url used to set up distributed training")
142
+
143
+ # Prototype models only
144
+ parser.add_argument("--weights", default=None, type=str, help="the weights enum name to load")
145
+
146
+ # Mixed precision training parameters
147
+ parser.add_argument("--amp", action="store_true", help="Use torch.cuda.amp for mixed precision training")
148
+
149
+ return parser
150
+
151
+
152
+ def main(args):
153
+ if args.weights and PM is None:
154
+ raise ImportError("The prototype module couldn't be found. Please install the latest torchvision nightly.")
155
+ if args.output_dir:
156
+ utils.mkdir(args.output_dir)
157
+
158
+ utils.init_distributed_mode(args)
159
+ print(args)
160
+
161
+ device = torch.device(args.device)
162
+
163
+ # Data loading code
164
+ print("Loading data")
165
+
166
+ dataset, num_classes = get_dataset(args.dataset, "train", get_transform(True, args), args.data_path)
167
+ dataset_test, _ = get_dataset(args.dataset, "val", get_transform(False, args), args.data_path)
168
+
169
+ print("Creating data loaders")
170
+ if args.distributed:
171
+ train_sampler = torch.utils.data.distributed.DistributedSampler(dataset)
172
+ test_sampler = torch.utils.data.distributed.DistributedSampler(dataset_test)
173
+ else:
174
+ train_sampler = torch.utils.data.RandomSampler(dataset)
175
+ test_sampler = torch.utils.data.SequentialSampler(dataset_test)
176
+
177
+ if args.aspect_ratio_group_factor >= 0:
178
+ group_ids = create_aspect_ratio_groups(dataset, k=args.aspect_ratio_group_factor)
179
+ train_batch_sampler = GroupedBatchSampler(train_sampler, group_ids, args.batch_size)
180
+ else:
181
+ train_batch_sampler = torch.utils.data.BatchSampler(train_sampler, args.batch_size, drop_last=True)
182
+
183
+ data_loader = torch.utils.data.DataLoader(
184
+ dataset, batch_sampler=train_batch_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
185
+ )
186
+
187
+ data_loader_test = torch.utils.data.DataLoader(
188
+ dataset_test, batch_size=1, sampler=test_sampler, num_workers=args.workers, collate_fn=utils.collate_fn
189
+ )
190
+
191
+ print("Creating model")
192
+ kwargs = {"trainable_backbone_layers": args.trainable_backbone_layers}
193
+ if "rcnn" in args.model:
194
+ if args.rpn_score_thresh is not None:
195
+ kwargs["rpn_score_thresh"] = args.rpn_score_thresh
196
+ if not args.weights:
197
+ model = torchvision.models.detection.__dict__[args.model](
198
+ pretrained=args.pretrained, num_classes=num_classes, **kwargs
199
+ )
200
+ else:
201
+ model = PM.detection.__dict__[args.model](weights=args.weights, num_classes=num_classes, **kwargs)
202
+ model.to(device)
203
+ if args.distributed and args.sync_bn:
204
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
205
+
206
+ model_without_ddp = model
207
+ if args.distributed:
208
+ model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[args.gpu])
209
+ model_without_ddp = model.module
210
+
211
+ params = [p for p in model.parameters() if p.requires_grad]
212
+ optimizer = torch.optim.SGD(params, lr=args.lr, momentum=args.momentum, weight_decay=args.weight_decay)
213
+
214
+ scaler = torch.cuda.amp.GradScaler() if args.amp else None
215
+
216
+ args.lr_scheduler = args.lr_scheduler.lower()
217
+ if args.lr_scheduler == "multisteplr":
218
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=args.lr_steps, gamma=args.lr_gamma)
219
+ elif args.lr_scheduler == "cosineannealinglr":
220
+ lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
221
+ else:
222
+ raise RuntimeError(
223
+ f"Invalid lr scheduler '{args.lr_scheduler}'. Only MultiStepLR and CosineAnnealingLR are supported."
224
+ )
225
+
226
+ if args.resume:
227
+ checkpoint = torch.load(args.resume, map_location="cpu")
228
+ model_without_ddp.load_state_dict(checkpoint["model"])
229
+ optimizer.load_state_dict(checkpoint["optimizer"])
230
+ lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
231
+ args.start_epoch = checkpoint["epoch"] + 1
232
+ if args.amp:
233
+ scaler.load_state_dict(checkpoint["scaler"])
234
+
235
+ if args.test_only:
236
+ evaluate(model, data_loader_test, device=device)
237
+ return
238
+
239
+ print("Start training")
240
+ start_time = time.time()
241
+ for epoch in range(args.start_epoch, args.epochs):
242
+ if args.distributed:
243
+ train_sampler.set_epoch(epoch)
244
+ train_one_epoch(model, optimizer, data_loader, device, epoch, args.print_freq, scaler)
245
+ lr_scheduler.step()
246
+ if args.output_dir:
247
+ checkpoint = {
248
+ "model": model_without_ddp.state_dict(),
249
+ "optimizer": optimizer.state_dict(),
250
+ "lr_scheduler": lr_scheduler.state_dict(),
251
+ "args": args,
252
+ "epoch": epoch,
253
+ }
254
+ if args.amp:
255
+ checkpoint["scaler"] = scaler.state_dict()
256
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, f"model_{epoch}.pth"))
257
+ utils.save_on_master(checkpoint, os.path.join(args.output_dir, "checkpoint.pth"))
258
+
259
+ # evaluate after every epoch
260
+ evaluate(model, data_loader_test, device=device)
261
+
262
+ total_time = time.time() - start_time
263
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
264
+ print(f"Training time {total_time_str}")
265
+
266
+
267
+ if __name__ == "__main__":
268
+ args = get_args_parser().parse_args()
269
+ main(args)
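
The script is normally driven from the command line as described in the docstring; the snippet below shows the equivalent programmatic call, with a placeholder dataset path (not a path that ships with this repository).

```python
from detection.train import get_args_parser, main

args = get_args_parser().parse_args([
    "--data-path", "/path/to/coco",          # placeholder COCO root
    "--model", "fasterrcnn_resnet50_fpn",
    "--epochs", "26",
    "--lr-steps", "16", "22",
    "--aspect-ratio-group-factor", "3",
])
main(args)
```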
DenseMammogram/detection/transforms.py ADDED
@@ -0,0 +1,283 @@
1
+ from typing import List, Tuple, Dict, Optional
2
+
3
+ import torch
4
+ import torchvision
5
+ from torch import nn, Tensor
6
+ from torchvision.transforms import functional as F
7
+ from torchvision.transforms import transforms as T
8
+
9
+
10
+ def _flip_coco_person_keypoints(kps, width):
11
+ flip_inds = [0, 2, 1, 4, 3, 6, 5, 8, 7, 10, 9, 12, 11, 14, 13, 16, 15]
12
+ flipped_data = kps[:, flip_inds]
13
+ flipped_data[..., 0] = width - flipped_data[..., 0]
14
+ # Maintain COCO convention that if visibility == 0, then x, y = 0
15
+ inds = flipped_data[..., 2] == 0
16
+ flipped_data[inds] = 0
17
+ return flipped_data
18
+
19
+
20
+ class Compose:
21
+ def __init__(self, transforms):
22
+ self.transforms = transforms
23
+
24
+ def __call__(self, image, target):
25
+ for t in self.transforms:
26
+ image, target = t(image, target)
27
+ return image, target
28
+
29
+
30
+ class RandomHorizontalFlip(T.RandomHorizontalFlip):
31
+ def forward(
32
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
33
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
34
+ if torch.rand(1) < self.p:
35
+ image = F.hflip(image)
36
+ if target is not None:
37
+ width, _ = F.get_image_size(image)
38
+ target["boxes"][:, [0, 2]] = width - target["boxes"][:, [2, 0]]
39
+ if "masks" in target:
40
+ target["masks"] = target["masks"].flip(-1)
41
+ if "keypoints" in target:
42
+ keypoints = target["keypoints"]
43
+ keypoints = _flip_coco_person_keypoints(keypoints, width)
44
+ target["keypoints"] = keypoints
45
+ return image, target
46
+
47
+
48
+ class ToTensor(nn.Module):
49
+ def forward(
50
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
51
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
52
+ image = F.pil_to_tensor(image)
53
+ image = F.convert_image_dtype(image)
54
+ return image, target
55
+
56
+
57
+ class PILToTensor(nn.Module):
58
+ def forward(
59
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
60
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
61
+ image = F.pil_to_tensor(image)
62
+ return image, target
63
+
64
+
65
+ class ConvertImageDtype(nn.Module):
66
+ def __init__(self, dtype: torch.dtype) -> None:
67
+ super().__init__()
68
+ self.dtype = dtype
69
+
70
+ def forward(
71
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
72
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
73
+ image = F.convert_image_dtype(image, self.dtype)
74
+ return image, target
75
+
76
+
77
+ class RandomIoUCrop(nn.Module):
78
+ def __init__(
79
+ self,
80
+ min_scale: float = 0.3,
81
+ max_scale: float = 1.0,
82
+ min_aspect_ratio: float = 0.5,
83
+ max_aspect_ratio: float = 2.0,
84
+ sampler_options: Optional[List[float]] = None,
85
+ trials: int = 40,
86
+ ):
87
+ super().__init__()
88
+ # Configuration similar to https://github.com/weiliu89/caffe/blob/ssd/examples/ssd/ssd_coco.py#L89-L174
89
+ self.min_scale = min_scale
90
+ self.max_scale = max_scale
91
+ self.min_aspect_ratio = min_aspect_ratio
92
+ self.max_aspect_ratio = max_aspect_ratio
93
+ if sampler_options is None:
94
+ sampler_options = [0.0, 0.1, 0.3, 0.5, 0.7, 0.9, 1.0]
95
+ self.options = sampler_options
96
+ self.trials = trials
97
+
98
+ def forward(
99
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
100
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
101
+ if target is None:
102
+ raise ValueError("The targets can't be None for this transform.")
103
+
104
+ if isinstance(image, torch.Tensor):
105
+ if image.ndimension() not in {2, 3}:
106
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
107
+ elif image.ndimension() == 2:
108
+ image = image.unsqueeze(0)
109
+
110
+ orig_w, orig_h = F.get_image_size(image)
111
+
112
+ while True:
113
+ # sample an option
114
+ idx = int(torch.randint(low=0, high=len(self.options), size=(1,)))
115
+ min_jaccard_overlap = self.options[idx]
116
+ if min_jaccard_overlap >= 1.0: # a value larger than 1 encodes the leave as-is option
117
+ return image, target
118
+
119
+ for _ in range(self.trials):
120
+ # check the aspect ratio limitations
121
+ r = self.min_scale + (self.max_scale - self.min_scale) * torch.rand(2)
122
+ new_w = int(orig_w * r[0])
123
+ new_h = int(orig_h * r[1])
124
+ aspect_ratio = new_w / new_h
125
+ if not (self.min_aspect_ratio <= aspect_ratio <= self.max_aspect_ratio):
126
+ continue
127
+
128
+ # check for 0 area crops
129
+ r = torch.rand(2)
130
+ left = int((orig_w - new_w) * r[0])
131
+ top = int((orig_h - new_h) * r[1])
132
+ right = left + new_w
133
+ bottom = top + new_h
134
+ if left == right or top == bottom:
135
+ continue
136
+
137
+ # check for any valid boxes with centers within the crop area
138
+ cx = 0.5 * (target["boxes"][:, 0] + target["boxes"][:, 2])
139
+ cy = 0.5 * (target["boxes"][:, 1] + target["boxes"][:, 3])
140
+ is_within_crop_area = (left < cx) & (cx < right) & (top < cy) & (cy < bottom)
141
+ if not is_within_crop_area.any():
142
+ continue
143
+
144
+ # check at least 1 box with jaccard limitations
145
+ boxes = target["boxes"][is_within_crop_area]
146
+ ious = torchvision.ops.boxes.box_iou(
147
+ boxes, torch.tensor([[left, top, right, bottom]], dtype=boxes.dtype, device=boxes.device)
148
+ )
149
+ if ious.max() < min_jaccard_overlap:
150
+ continue
151
+
152
+ # keep only valid boxes and perform cropping
153
+ target["boxes"] = boxes
154
+ target["labels"] = target["labels"][is_within_crop_area]
155
+ target["boxes"][:, 0::2] -= left
156
+ target["boxes"][:, 1::2] -= top
157
+ target["boxes"][:, 0::2].clamp_(min=0, max=new_w)
158
+ target["boxes"][:, 1::2].clamp_(min=0, max=new_h)
159
+ image = F.crop(image, top, left, new_h, new_w)
160
+
161
+ return image, target
162
+
163
+
164
+ class RandomZoomOut(nn.Module):
165
+ def __init__(
166
+ self, fill: Optional[List[float]] = None, side_range: Tuple[float, float] = (1.0, 4.0), p: float = 0.5
167
+ ):
168
+ super().__init__()
169
+ if fill is None:
170
+ fill = [0.0, 0.0, 0.0]
171
+ self.fill = fill
172
+ self.side_range = side_range
173
+ if side_range[0] < 1.0 or side_range[0] > side_range[1]:
174
+ raise ValueError(f"Invalid canvas side range provided {side_range}.")
175
+ self.p = p
176
+
177
+ @torch.jit.unused
178
+ def _get_fill_value(self, is_pil):
179
+ # type: (bool) -> int
180
+ # We fake the type to make it work on JIT
181
+ return tuple(int(x) for x in self.fill) if is_pil else 0
182
+
183
+ def forward(
184
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
185
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
186
+ if isinstance(image, torch.Tensor):
187
+ if image.ndimension() not in {2, 3}:
188
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
189
+ elif image.ndimension() == 2:
190
+ image = image.unsqueeze(0)
191
+
192
+ if torch.rand(1) < self.p:
193
+ return image, target
194
+
195
+ orig_w, orig_h = F.get_image_size(image)
196
+
197
+ r = self.side_range[0] + torch.rand(1) * (self.side_range[1] - self.side_range[0])
198
+ canvas_width = int(orig_w * r)
199
+ canvas_height = int(orig_h * r)
200
+
201
+ r = torch.rand(2)
202
+ left = int((canvas_width - orig_w) * r[0])
203
+ top = int((canvas_height - orig_h) * r[1])
204
+ right = canvas_width - (left + orig_w)
205
+ bottom = canvas_height - (top + orig_h)
206
+
207
+ if torch.jit.is_scripting():
208
+ fill = 0
209
+ else:
210
+ fill = self._get_fill_value(F._is_pil_image(image))
211
+
212
+ image = F.pad(image, [left, top, right, bottom], fill=fill)
213
+ if isinstance(image, torch.Tensor):
214
+ v = torch.tensor(self.fill, device=image.device, dtype=image.dtype).view(-1, 1, 1)
215
+ image[..., :top, :] = image[..., :, :left] = image[..., (top + orig_h) :, :] = image[
216
+ ..., :, (left + orig_w) :
217
+ ] = v
218
+
219
+ if target is not None:
220
+ target["boxes"][:, 0::2] += left
221
+ target["boxes"][:, 1::2] += top
222
+
223
+ return image, target
224
+
225
+
226
+ class RandomPhotometricDistort(nn.Module):
227
+ def __init__(
228
+ self,
229
+ contrast: Tuple[float] = (0.5, 1.5),
230
+ saturation: Tuple[float] = (0.5, 1.5),
231
+ hue: Tuple[float] = (-0.05, 0.05),
232
+ brightness: Tuple[float] = (0.875, 1.125),
233
+ p: float = 0.5,
234
+ ):
235
+ super().__init__()
236
+ self._brightness = T.ColorJitter(brightness=brightness)
237
+ self._contrast = T.ColorJitter(contrast=contrast)
238
+ self._hue = T.ColorJitter(hue=hue)
239
+ self._saturation = T.ColorJitter(saturation=saturation)
240
+ self.p = p
241
+
242
+ def forward(
243
+ self, image: Tensor, target: Optional[Dict[str, Tensor]] = None
244
+ ) -> Tuple[Tensor, Optional[Dict[str, Tensor]]]:
245
+ if isinstance(image, torch.Tensor):
246
+ if image.ndimension() not in {2, 3}:
247
+ raise ValueError(f"image should be 2/3 dimensional. Got {image.ndimension()} dimensions.")
248
+ elif image.ndimension() == 2:
249
+ image = image.unsqueeze(0)
250
+
251
+ r = torch.rand(7)
252
+
253
+ if r[0] < self.p:
254
+ image = self._brightness(image)
255
+
256
+ contrast_before = r[1] < 0.5
257
+ if contrast_before:
258
+ if r[2] < self.p:
259
+ image = self._contrast(image)
260
+
261
+ if r[3] < self.p:
262
+ image = self._saturation(image)
263
+
264
+ if r[4] < self.p:
265
+ image = self._hue(image)
266
+
267
+ if not contrast_before:
268
+ if r[5] < self.p:
269
+ image = self._contrast(image)
270
+
271
+ if r[6] < self.p:
272
+ channels = F.get_image_num_channels(image)
273
+ permutation = torch.randperm(channels)
274
+
275
+ is_pil = F._is_pil_image(image)
276
+ if is_pil:
277
+ image = F.pil_to_tensor(image)
278
+ image = F.convert_image_dtype(image)
279
+ image = image[..., permutation, :, :]
280
+ if is_pil:
281
+ image = F.to_pil_image(image)
282
+
283
+ return image, target
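
All of these transforms operate on an (image, target) pair rather than on the image alone, so they compose as below. A minimal sketch with a placeholder image and box:

```python
import torch
from PIL import Image
import detection.transforms as T

train_tf = T.Compose([
    T.RandomHorizontalFlip(p=0.5),
    T.PILToTensor(),
    T.ConvertImageDtype(torch.float),
])

img = Image.new("RGB", (512, 512))                       # placeholder image
target = {"boxes": torch.tensor([[30.0, 40.0, 200.0, 220.0]]),
          "labels": torch.tensor([1])}
img, target = train_tf(img, target)                      # tensor image, boxes kept in sync
```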
DenseMammogram/detection/utils.py ADDED
@@ -0,0 +1,282 @@
1
+ import datetime
2
+ import errno
3
+ import os
4
+ import time
5
+ from collections import defaultdict, deque
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+
10
+
11
+ class SmoothedValue:
12
+ """Track a series of values and provide access to smoothed values over a
13
+ window or the global series average.
14
+ """
15
+
16
+ def __init__(self, window_size=20, fmt=None):
17
+ if fmt is None:
18
+ fmt = "{median:.4f} ({global_avg:.4f})"
19
+ self.deque = deque(maxlen=window_size)
20
+ self.total = 0.0
21
+ self.count = 0
22
+ self.fmt = fmt
23
+
24
+ def update(self, value, n=1):
25
+ self.deque.append(value)
26
+ self.count += n
27
+ self.total += value * n
28
+
29
+ def synchronize_between_processes(self):
30
+ """
31
+ Warning: does not synchronize the deque!
32
+ """
33
+ if not is_dist_avail_and_initialized():
34
+ return
35
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
36
+ dist.barrier()
37
+ dist.all_reduce(t)
38
+ t = t.tolist()
39
+ self.count = int(t[0])
40
+ self.total = t[1]
41
+
42
+ @property
43
+ def median(self):
44
+ d = torch.tensor(list(self.deque))
45
+ return d.median().item()
46
+
47
+ @property
48
+ def avg(self):
49
+ d = torch.tensor(list(self.deque), dtype=torch.float32)
50
+ return d.mean().item()
51
+
52
+ @property
53
+ def global_avg(self):
54
+ return self.total / self.count
55
+
56
+ @property
57
+ def max(self):
58
+ return max(self.deque)
59
+
60
+ @property
61
+ def value(self):
62
+ return self.deque[-1]
63
+
64
+ def __str__(self):
65
+ return self.fmt.format(
66
+ median=self.median, avg=self.avg, global_avg=self.global_avg, max=self.max, value=self.value
67
+ )
68
+
69
+
70
+ def all_gather(data):
71
+ """
72
+ Run all_gather on arbitrary picklable data (not necessarily tensors)
73
+ Args:
74
+ data: any picklable object
75
+ Returns:
76
+ list[data]: list of data gathered from each rank
77
+ """
78
+ world_size = get_world_size()
79
+ if world_size == 1:
80
+ return [data]
81
+ data_list = [None] * world_size
82
+ dist.all_gather_object(data_list, data)
83
+ return data_list
84
+
85
+
86
+ def reduce_dict(input_dict, average=True):
87
+ """
88
+ Args:
89
+ input_dict (dict): all the values will be reduced
90
+ average (bool): whether to do average or sum
91
+ Reduce the values in the dictionary from all processes so that all processes
92
+ have the averaged results. Returns a dict with the same fields as
93
+ input_dict, after reduction.
94
+ """
95
+ world_size = get_world_size()
96
+ if world_size < 2:
97
+ return input_dict
98
+ with torch.inference_mode():
99
+ names = []
100
+ values = []
101
+ # sort the keys so that they are consistent across processes
102
+ for k in sorted(input_dict.keys()):
103
+ names.append(k)
104
+ values.append(input_dict[k])
105
+ values = torch.stack(values, dim=0)
106
+ dist.all_reduce(values)
107
+ if average:
108
+ values /= world_size
109
+ reduced_dict = {k: v for k, v in zip(names, values)}
110
+ return reduced_dict
111
+
112
+
113
+ class MetricLogger:
114
+ def __init__(self, delimiter="\t"):
115
+ self.meters = defaultdict(SmoothedValue)
116
+ self.delimiter = delimiter
117
+
118
+ def update(self, **kwargs):
119
+ for k, v in kwargs.items():
120
+ if isinstance(v, torch.Tensor):
121
+ v = v.item()
122
+ assert isinstance(v, (float, int))
123
+ self.meters[k].update(v)
124
+
125
+ def __getattr__(self, attr):
126
+ if attr in self.meters:
127
+ return self.meters[attr]
128
+ if attr in self.__dict__:
129
+ return self.__dict__[attr]
130
+ raise AttributeError(f"'{type(self).__name__}' object has no attribute '{attr}'")
131
+
132
+ def __str__(self):
133
+ loss_str = []
134
+ for name, meter in self.meters.items():
135
+ loss_str.append(f"{name}: {str(meter)}")
136
+ return self.delimiter.join(loss_str)
137
+
138
+ def synchronize_between_processes(self):
139
+ for meter in self.meters.values():
140
+ meter.synchronize_between_processes()
141
+
142
+ def add_meter(self, name, meter):
143
+ self.meters[name] = meter
144
+
145
+ def log_every(self, iterable, print_freq, header=None):
146
+ i = 0
147
+ if not header:
148
+ header = ""
149
+ start_time = time.time()
150
+ end = time.time()
151
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
152
+ data_time = SmoothedValue(fmt="{avg:.4f}")
153
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
154
+ if torch.cuda.is_available():
155
+ log_msg = self.delimiter.join(
156
+ [
157
+ header,
158
+ "[{0" + space_fmt + "}/{1}]",
159
+ "eta: {eta}",
160
+ "{meters}",
161
+ "time: {time}",
162
+ "data: {data}",
163
+ "max mem: {memory:.0f}",
164
+ ]
165
+ )
166
+ else:
167
+ log_msg = self.delimiter.join(
168
+ [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}", "time: {time}", "data: {data}"]
169
+ )
170
+ MB = 1024.0 * 1024.0
171
+ for obj in iterable:
172
+ data_time.update(time.time() - end)
173
+ yield obj
174
+ iter_time.update(time.time() - end)
175
+ if i % print_freq == 0 or i == len(iterable) - 1:
176
+ eta_seconds = iter_time.global_avg * (len(iterable) - i)
177
+ eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
178
+ if torch.cuda.is_available():
179
+ print(
180
+ log_msg.format(
181
+ i,
182
+ len(iterable),
183
+ eta=eta_string,
184
+ meters=str(self),
185
+ time=str(iter_time),
186
+ data=str(data_time),
187
+ memory=torch.cuda.max_memory_allocated() / MB,
188
+ )
189
+ )
190
+ else:
191
+ print(
192
+ log_msg.format(
193
+ i, len(iterable), eta=eta_string, meters=str(self), time=str(iter_time), data=str(data_time)
194
+ )
195
+ )
196
+ i += 1
197
+ end = time.time()
198
+ total_time = time.time() - start_time
199
+ total_time_str = str(datetime.timedelta(seconds=int(total_time)))
200
+ print(f"{header} Total time: {total_time_str} ({total_time / len(iterable):.4f} s / it)")
201
+
202
+
203
+ def collate_fn(batch):
204
+ return tuple(zip(*batch))
205
+
206
+
207
+ def mkdir(path):
208
+ try:
209
+ os.makedirs(path)
210
+ except OSError as e:
211
+ if e.errno != errno.EEXIST:
212
+ raise
213
+
214
+
215
+ def setup_for_distributed(is_master):
216
+ """
217
+ This function disables printing when not in master process
218
+ """
219
+ import builtins as __builtin__
220
+
221
+ builtin_print = __builtin__.print
222
+
223
+ def print(*args, **kwargs):
224
+ force = kwargs.pop("force", False)
225
+ if is_master or force:
226
+ builtin_print(*args, **kwargs)
227
+
228
+ __builtin__.print = print
229
+
230
+
231
+ def is_dist_avail_and_initialized():
232
+ if not dist.is_available():
233
+ return False
234
+ if not dist.is_initialized():
235
+ return False
236
+ return True
237
+
238
+
239
+ def get_world_size():
240
+ if not is_dist_avail_and_initialized():
241
+ return 1
242
+ return dist.get_world_size()
243
+
244
+
245
+ def get_rank():
246
+ if not is_dist_avail_and_initialized():
247
+ return 0
248
+ return dist.get_rank()
249
+
250
+
251
+ def is_main_process():
252
+ return get_rank() == 0
253
+
254
+
255
+ def save_on_master(*args, **kwargs):
256
+ if is_main_process():
257
+ torch.save(*args, **kwargs)
258
+
259
+
260
+ def init_distributed_mode(args):
261
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
262
+ args.rank = int(os.environ["RANK"])
263
+ args.world_size = int(os.environ["WORLD_SIZE"])
264
+ args.gpu = int(os.environ["LOCAL_RANK"])
265
+ elif "SLURM_PROCID" in os.environ:
266
+ args.rank = int(os.environ["SLURM_PROCID"])
267
+ args.gpu = args.rank % torch.cuda.device_count()
268
+ else:
269
+ print("Not using distributed mode")
270
+ args.distributed = False
271
+ return
272
+
273
+ args.distributed = True
274
+
275
+ torch.cuda.set_device(args.gpu)
276
+ args.dist_backend = "nccl"
277
+ print(f"| distributed init (rank {args.rank}): {args.dist_url}", flush=True)
278
+ torch.distributed.init_process_group(
279
+ backend=args.dist_backend, init_method=args.dist_url, world_size=args.world_size, rank=args.rank
280
+ )
281
+ torch.distributed.barrier()
282
+ setup_for_distributed(args.rank == 0)
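
A quick sketch of the logging helpers above outside any distributed run; the toy loop stands in for a real training loop.

```python
import time
from detection.utils import MetricLogger, SmoothedValue

logger = MetricLogger(delimiter="  ")
logger.add_meter("lr", SmoothedValue(window_size=1, fmt="{value:.6f}"))

for step in logger.log_every(range(10), print_freq=5, header="Demo:"):
    time.sleep(0.01)                                     # stand-in for a training step
    logger.update(loss=1.0 / (step + 1), lr=0.02)
```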
DenseMammogram/ensemble_boxes/__init__.py ADDED
@@ -0,0 +1,9 @@
1
+ # coding: utf-8
2
+ __author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3
+
4
+ from .ensemble_boxes_wbf import weighted_boxes_fusion
5
+ from .ensemble_boxes_nmw import non_maximum_weighted
6
+ from .ensemble_boxes_nms import nms_method
7
+ from .ensemble_boxes_nms import nms
8
+ from .ensemble_boxes_nms import soft_nms
9
+ from .ensemble_boxes_wbf_3d import weighted_boxes_fusion_3d
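
These re-exports make the fusion functions importable directly from `ensemble_boxes`. A hedged sketch of fusing two models' detections on one image, with toy boxes normalized to [0, 1] as the library expects:

```python
from ensemble_boxes import weighted_boxes_fusion

boxes_list  = [[[0.10, 0.10, 0.50, 0.50]],   # model 1
               [[0.12, 0.11, 0.52, 0.49]]]   # model 2
scores_list = [[0.90], [0.75]]
labels_list = [[1], [1]]

boxes, scores, labels = weighted_boxes_fusion(
    boxes_list, scores_list, labels_list, weights=[2, 1], iou_thr=0.55, skip_box_thr=0.0
)
```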
DenseMammogram/ensemble_boxes/ensemble_boxes_nms.py ADDED
@@ -0,0 +1,249 @@
1
+ # coding: utf-8
2
+ __author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3
+
4
+ import numpy as np
5
+ from numba import jit
6
+
7
+
8
+ def prepare_boxes(boxes, scores, labels):
9
+ result_boxes = boxes.copy()
10
+
11
+ cond = (result_boxes < 0)
12
+ cond_sum = cond.astype(np.int32).sum()
13
+ if cond_sum > 0:
14
+ print('Warning. Fixed {} boxes coordinates < 0'.format(cond_sum))
15
+ result_boxes[cond] = 0
16
+
17
+ cond = (result_boxes > 1)
18
+ cond_sum = cond.astype(np.int32).sum()
19
+ if cond_sum > 0:
20
+ print('Warning. Fixed {} boxes coordinates > 1. Check that your boxes were normalized to [0, 1]'.format(cond_sum))
21
+ result_boxes[cond] = 1
22
+
23
+ boxes1 = result_boxes.copy()
24
+ result_boxes[:, 0] = np.min(boxes1[:, [0, 2]], axis=1)
25
+ result_boxes[:, 2] = np.max(boxes1[:, [0, 2]], axis=1)
26
+ result_boxes[:, 1] = np.min(boxes1[:, [1, 3]], axis=1)
27
+ result_boxes[:, 3] = np.max(boxes1[:, [1, 3]], axis=1)
28
+
29
+ area = (result_boxes[:, 2] - result_boxes[:, 0]) * (result_boxes[:, 3] - result_boxes[:, 1])
30
+ cond = (area == 0)
31
+ cond_sum = cond.astype(np.int32).sum()
32
+ if cond_sum > 0:
33
+ print('Warning. Removed {} boxes with zero area!'.format(cond_sum))
34
+ result_boxes = result_boxes[area > 0]
35
+ scores = scores[area > 0]
36
+ labels = labels[area > 0]
37
+
38
+ return result_boxes, scores, labels
39
+
40
+
41
+ def cpu_soft_nms_float(dets, sc, Nt, sigma, thresh, method):
42
+ """
43
+ Based on: https://github.com/DocF/Soft-NMS/blob/master/soft_nms.py
44
+ It's different from original soft-NMS because we have float coordinates on range [0; 1]
45
+
46
+ :param dets: boxes format [x1, y1, x2, y2]
47
+ :param sc: scores for boxes
48
+ :param Nt: required iou
49
+ :param sigma:
50
+ :param thresh:
51
+ :param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
52
+ :return: index of boxes to keep
53
+ """
54
+
55
+ # indexes concatenate boxes with the last column
56
+ N = dets.shape[0]
57
+ indexes = np.array([np.arange(N)])
58
+ dets = np.concatenate((dets, indexes.T), axis=1)
59
+
60
+ # the order of boxes coordinate is [y1, x1, y2, x2]
61
+ y1 = dets[:, 1]
62
+ x1 = dets[:, 0]
63
+ y2 = dets[:, 3]
64
+ x2 = dets[:, 2]
65
+ scores = sc
66
+ areas = (x2 - x1) * (y2 - y1)
67
+
68
+ for i in range(N):
69
+ # intermediate parameters for later parameters exchange
70
+ tBD = dets[i, :].copy()
71
+ tscore = scores[i].copy()
72
+ tarea = areas[i].copy()
73
+ pos = i + 1
74
+
75
+ #
76
+ if i != N - 1:
77
+ maxscore = np.max(scores[pos:], axis=0)
78
+ maxpos = np.argmax(scores[pos:], axis=0)
79
+ else:
80
+ maxscore = scores[-1]
81
+ maxpos = 0
82
+ if tscore < maxscore:
83
+ dets[i, :] = dets[maxpos + i + 1, :]
84
+ dets[maxpos + i + 1, :] = tBD
85
+ tBD = dets[i, :]
86
+
87
+ scores[i] = scores[maxpos + i + 1]
88
+ scores[maxpos + i + 1] = tscore
89
+ tscore = scores[i]
90
+
91
+ areas[i] = areas[maxpos + i + 1]
92
+ areas[maxpos + i + 1] = tarea
93
+ tarea = areas[i]
94
+
95
+ # IoU calculate
96
+ xx1 = np.maximum(dets[i, 1], dets[pos:, 1])
97
+ yy1 = np.maximum(dets[i, 0], dets[pos:, 0])
98
+ xx2 = np.minimum(dets[i, 3], dets[pos:, 3])
99
+ yy2 = np.minimum(dets[i, 2], dets[pos:, 2])
100
+
101
+ w = np.maximum(0.0, xx2 - xx1)
102
+ h = np.maximum(0.0, yy2 - yy1)
103
+ inter = w * h
104
+ ovr = inter / (areas[i] + areas[pos:] - inter)
105
+
106
+ # Three methods: 1.linear 2.gaussian 3.original NMS
107
+ if method == 1: # linear
108
+ weight = np.ones(ovr.shape)
109
+ weight[ovr > Nt] = weight[ovr > Nt] - ovr[ovr > Nt]
110
+ elif method == 2: # gaussian
111
+ weight = np.exp(-(ovr * ovr) / sigma)
112
+ else: # original NMS
113
+ weight = np.ones(ovr.shape)
114
+ weight[ovr > Nt] = 0
115
+
116
+ scores[pos:] = weight * scores[pos:]
117
+
118
+ # select the boxes and keep the corresponding indexes
119
+ inds = dets[:, 4][scores > thresh]
120
+ keep = inds.astype(int)
121
+ return keep
122
+
123
+
124
+ @jit(nopython=True)
125
+ def nms_float_fast(dets, scores, thresh):
126
+ """
127
+ # It's different from original nms because we have float coordinates on range [0; 1]
128
+ :param dets: numpy array of boxes with shape: (N, 5). Order: x1, y1, x2, y2, score. All variables in range [0; 1]
129
+ :param thresh: IoU value for boxes
130
+ :return: index of boxes to keep
131
+ """
132
+ x1 = dets[:, 0]
133
+ y1 = dets[:, 1]
134
+ x2 = dets[:, 2]
135
+ y2 = dets[:, 3]
136
+
137
+ areas = (x2 - x1) * (y2 - y1)
138
+ order = scores.argsort()[::-1]
139
+
140
+ keep = []
141
+ while order.size > 0:
142
+ i = order[0]
143
+ keep.append(i)
144
+ xx1 = np.maximum(x1[i], x1[order[1:]])
145
+ yy1 = np.maximum(y1[i], y1[order[1:]])
146
+ xx2 = np.minimum(x2[i], x2[order[1:]])
147
+ yy2 = np.minimum(y2[i], y2[order[1:]])
148
+
149
+ w = np.maximum(0.0, xx2 - xx1)
150
+ h = np.maximum(0.0, yy2 - yy1)
151
+ inter = w * h
152
+ ovr = inter / (areas[i] + areas[order[1:]] - inter)
153
+ inds = np.where(ovr <= thresh)[0]
154
+ order = order[inds + 1]
155
+
156
+ return keep
157
+
158
+
159
+ def nms_method(boxes, scores, labels, method=3, iou_thr=0.5, sigma=0.5, thresh=0.001, weights=None):
160
+ """
161
+ :param boxes: list of boxes predictions from each model, each box is 4 numbers.
162
+ It has 3 dimensions (models_number, model_preds, 4)
163
+ Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
164
+ :param scores: list of scores for each model
165
+ :param labels: list of labels for each model
166
+ :param method: 1 - linear soft-NMS, 2 - gaussian soft-NMS, 3 - standard NMS
167
+ :param iou_thr: IoU value for boxes to be a match
168
+ :param sigma: Sigma value for SoftNMS
169
+ :param thresh: threshold for boxes to keep (important for SoftNMS)
170
+ :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
171
+
172
+ :return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
173
+ :return: scores: confidence scores
174
+ :return: labels: boxes labels
175
+ """
176
+
177
+ # If weights are specified
178
+ if weights is not None:
179
+ if len(boxes) != len(weights):
180
+ print('Incorrect number of weights: {}. Must be: {}. Skip it'.format(len(weights), len(boxes)))
181
+ else:
182
+ weights = np.array(weights)
183
+ for i in range(len(weights)):
184
+ scores[i] = (np.array(scores[i]) * weights[i]) / weights.sum()
185
+
186
+ # We concatenate everything
187
+ boxes = np.concatenate(boxes)
188
+ scores = np.concatenate(scores)
189
+ labels = np.concatenate(labels)
190
+
191
+ # Fix coordinates and removed zero area boxes
192
+ boxes, scores, labels = prepare_boxes(boxes, scores, labels)
193
+
194
+ # Run NMS independently for each label
195
+ unique_labels = np.unique(labels)
196
+ final_boxes = []
197
+ final_scores = []
198
+ final_labels = []
199
+ for l in unique_labels:
200
+ condition = (labels == l)
201
+ boxes_by_label = boxes[condition]
202
+ scores_by_label = scores[condition]
203
+ labels_by_label = np.array([l] * len(boxes_by_label))
204
+
205
+ if method != 3:
206
+ keep = cpu_soft_nms_float(boxes_by_label.copy(), scores_by_label.copy(), Nt=iou_thr, sigma=sigma, thresh=thresh, method=method)
207
+ else:
208
+ # Use faster function
209
+ keep = nms_float_fast(boxes_by_label, scores_by_label, thresh=iou_thr)
210
+
211
+ final_boxes.append(boxes_by_label[keep])
212
+ final_scores.append(scores_by_label[keep])
213
+ final_labels.append(labels_by_label[keep])
214
+ final_boxes = np.concatenate(final_boxes)
215
+ final_scores = np.concatenate(final_scores)
216
+ final_labels = np.concatenate(final_labels)
217
+
218
+ return final_boxes, final_scores, final_labels
219
+
220
+
221
+ def nms(boxes, scores, labels, iou_thr=0.5, weights=None):
222
+ """
223
+ Short call for standard NMS
224
+
225
+ :param boxes:
226
+ :param scores:
227
+ :param labels:
228
+ :param iou_thr:
229
+ :param weights:
230
+ :return:
231
+ """
232
+ return nms_method(boxes, scores, labels, method=3, iou_thr=iou_thr, weights=weights)
233
+
234
+
235
+ def soft_nms(boxes, scores, labels, method=2, iou_thr=0.5, sigma=0.5, thresh=0.001, weights=None):
236
+ """
237
+ Short call for Soft-NMS
238
+
239
+ :param boxes:
240
+ :param scores:
241
+ :param labels:
242
+ :param method:
243
+ :param iou_thr:
244
+ :param sigma:
245
+ :param thresh:
246
+ :param weights:
247
+ :return:
248
+ """
249
+ return nms_method(boxes, scores, labels, method=method, iou_thr=iou_thr, sigma=sigma, thresh=thresh, weights=weights)
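
A short sketch contrasting the two entry points above on toy, normalized boxes: standard NMS drops the weaker overlapping box, while Gaussian soft-NMS only decays its score.

```python
import numpy as np
from ensemble_boxes import nms, soft_nms

boxes  = [np.array([[0.10, 0.10, 0.50, 0.50], [0.12, 0.12, 0.52, 0.52]])]
scores = [np.array([0.90, 0.60])]
labels = [np.array([1, 1])]

hard_boxes, hard_scores, hard_labels = nms(boxes, scores, labels, iou_thr=0.5)
soft_boxes, soft_scores, soft_labels = soft_nms(boxes, scores, labels,
                                                method=2, iou_thr=0.5, sigma=0.5, thresh=0.001)
```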
DenseMammogram/ensemble_boxes/ensemble_boxes_nmw.py ADDED
@@ -0,0 +1,202 @@
1
+ # coding: utf-8
2
+ __author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3
+
4
+ """
5
+ Method described in:
6
+ CAD: Scale Invariant Framework for Real-Time Object Detection
7
+ http://openaccess.thecvf.com/content_ICCV_2017_workshops/papers/w14/Zhou_CAD_Scale_Invariant_ICCV_2017_paper.pdf
8
+ """
9
+
10
+ import warnings
11
+ import numpy as np
12
+ from numba import jit
13
+
14
+
15
+ @jit(nopython=True)
16
+ def bb_intersection_over_union(A, B):
17
+ xA = max(A[0], B[0])
18
+ yA = max(A[1], B[1])
19
+ xB = min(A[2], B[2])
20
+ yB = min(A[3], B[3])
21
+
22
+ # compute the area of intersection rectangle
23
+ interArea = max(0, xB - xA) * max(0, yB - yA)
24
+
25
+ if interArea == 0:
26
+ return 0.0
27
+
28
+ # compute the area of both the prediction and ground-truth rectangles
29
+ boxAArea = (A[2] - A[0]) * (A[3] - A[1])
30
+ boxBArea = (B[2] - B[0]) * (B[3] - B[1])
31
+
32
+ iou = interArea / float(boxAArea + boxBArea - interArea)
33
+ return iou
34
+
35
+
36
+ def prefilter_boxes(boxes, scores, labels, weights, thr):
37
+ # Create dict with boxes stored by its label
38
+ new_boxes = dict()
39
+ for t in range(len(boxes)):
40
+
41
+ if len(boxes[t]) != len(scores[t]):
42
+ print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(boxes[t]),
43
+ len(scores[t])))
44
+ exit()
45
+
46
+ if len(boxes[t]) != len(labels[t]):
47
+ print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(boxes[t]),
48
+ len(labels[t])))
49
+ exit()
50
+
51
+ for j in range(len(boxes[t])):
52
+ score = scores[t][j]
53
+ if score < thr:
54
+ continue
55
+ label = int(labels[t][j])
56
+ box_part = boxes[t][j]
57
+ x1 = float(box_part[0])
58
+ y1 = float(box_part[1])
59
+ x2 = float(box_part[2])
60
+ y2 = float(box_part[3])
61
+
62
+ # Box data checks
63
+ if x2 < x1:
64
+ warnings.warn('X2 < X1 value in box. Swap them.')
65
+ x1, x2 = x2, x1
66
+ if y2 < y1:
67
+ warnings.warn('Y2 < Y1 value in box. Swap them.')
68
+ y1, y2 = y2, y1
69
+ if x1 < 0:
70
+ warnings.warn('X1 < 0 in box. Set it to 0.')
71
+ x1 = 0
72
+ if x1 > 1:
73
+ warnings.warn('X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
74
+ x1 = 1
75
+ if x2 < 0:
76
+ warnings.warn('X2 < 0 in box. Set it to 0.')
77
+ x2 = 0
78
+ if x2 > 1:
79
+ warnings.warn('X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
80
+ x2 = 1
81
+ if y1 < 0:
82
+ warnings.warn('Y1 < 0 in box. Set it to 0.')
83
+ y1 = 0
84
+ if y1 > 1:
85
+ warnings.warn('Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
86
+ y1 = 1
87
+ if y2 < 0:
88
+ warnings.warn('Y2 < 0 in box. Set it to 0.')
89
+ y2 = 0
90
+ if y2 > 1:
91
+ warnings.warn('Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
92
+ y2 = 1
93
+ if (x2 - x1) * (y2 - y1) == 0.0:
94
+ warnings.warn("Zero area box skipped: {}.".format(box_part))
95
+ continue
96
+
97
+ b = [int(label), float(score) * weights[t], x1, y1, x2, y2]
98
+ if label not in new_boxes:
99
+ new_boxes[label] = []
100
+ new_boxes[label].append(b)
101
+
102
+ # Sort each list in dict by score and transform it to numpy array
103
+ for k in new_boxes:
104
+ current_boxes = np.array(new_boxes[k])
105
+ new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
106
+
107
+ return new_boxes
108
+
109
+
110
+ def get_weighted_box(boxes):
111
+ """
112
+ Create weighted box for set of boxes
113
+ :param boxes: set of boxes to fuse
114
+ :return: weighted box
115
+ """
116
+
117
+ box = np.zeros(6, dtype=np.float32)
118
+ best_box = boxes[0]
119
+ conf = 0
120
+ for b in boxes:
121
+ iou = bb_intersection_over_union(b[2:], best_box[2:])
122
+ weight = b[1] * iou
123
+ box[2:] += (weight * b[2:])
124
+ conf += weight
125
+ box[0] = best_box[0]
126
+ box[1] = best_box[1]
127
+ box[2:] /= conf
128
+ return box
129
+
130
+
131
+ def find_matching_box(boxes_list, new_box, match_iou):
132
+ best_iou = match_iou
133
+ best_index = -1
134
+ for i in range(len(boxes_list)):
135
+ box = boxes_list[i]
136
+ if box[0] != new_box[0]:
137
+ continue
138
+ iou = bb_intersection_over_union(box[2:], new_box[2:])
139
+ if iou > best_iou:
140
+ best_index = i
141
+ best_iou = iou
142
+
143
+ return best_index, best_iou
144
+
145
+
146
+ def non_maximum_weighted(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0):
147
+ '''
148
+ :param boxes_list: list of boxes predictions from each model, each box is 4 numbers.
149
+ It has 3 dimensions (models_number, model_preds, 4)
150
+ Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
151
+ :param scores_list: list of scores for each model
152
+ :param labels_list: list of labels for each model
153
+ :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
154
+ :param iou_thr: IoU value for boxes to be a match
155
+ :param skip_box_thr: exclude boxes with score lower than this variable
156
+
157
+ :return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
158
+ :return: scores: confidence scores
159
+ :return: labels: boxes labels
160
+ '''
161
+
162
+ if weights is None:
163
+ weights = np.ones(len(boxes_list))
164
+ if len(weights) != len(boxes_list):
165
+ print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
166
+ weights = np.ones(len(boxes_list))
167
+ weights = np.array(weights) / max(weights)
168
+ # for i in range(len(weights)):
169
+ # scores_list[i] = (np.array(scores_list[i]) * weights[i])
170
+
171
+ filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
172
+ if len(filtered_boxes) == 0:
173
+ return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
174
+
175
+ overall_boxes = []
176
+ for label in filtered_boxes:
177
+ boxes = filtered_boxes[label]
178
+ new_boxes = []
179
+ main_boxes = []
180
+
181
+ # Clusterize boxes
182
+ for j in range(0, len(boxes)):
183
+ index, best_iou = find_matching_box(main_boxes, boxes[j], iou_thr)
184
+ if index != -1:
185
+ new_boxes[index].append(boxes[j].copy())
186
+ else:
187
+ new_boxes.append([boxes[j].copy()])
188
+ main_boxes.append(boxes[j].copy())
189
+
190
+ weighted_boxes = []
191
+ for j in range(0, len(new_boxes)):
192
+ box = get_weighted_box(new_boxes[j])
193
+ weighted_boxes.append(box.copy())
194
+
195
+ overall_boxes.append(np.array(weighted_boxes))
196
+
197
+ overall_boxes = np.concatenate(overall_boxes, axis=0)
198
+ overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
199
+ boxes = overall_boxes[:, 2:]
200
+ scores = overall_boxes[:, 1]
201
+ labels = overall_boxes[:, 0]
202
+ return boxes, scores, labels
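A minimal usage sketch for non_maximum_weighted (not part of this commit). The coordinates must already be normalized to [0, 1]; the 5000-pixel image size and the box values below are invented for illustration only.

import numpy as np
from ensemble_boxes import non_maximum_weighted

IMG_SIZE = 5000.0  # assumed image side, used only to normalize the pixel boxes below
boxes_list = [
    [[100 / IMG_SIZE, 200 / IMG_SIZE, 400 / IMG_SIZE, 600 / IMG_SIZE]],  # model 1
    [[120 / IMG_SIZE, 210 / IMG_SIZE, 390 / IMG_SIZE, 580 / IMG_SIZE]],  # model 2
]
scores_list = [[0.9], [0.7]]
labels_list = [[0], [0]]

# Overlapping detections with the same label are fused into a single weighted box.
boxes, scores, labels = non_maximum_weighted(
    boxes_list, scores_list, labels_list,
    weights=[2, 1], iou_thr=0.5, skip_box_thr=0.05)
print(boxes * IMG_SIZE, scores, labels)  # back to pixel coordinates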
DenseMammogram/ensemble_boxes/ensemble_boxes_wbf.py ADDED
@@ -0,0 +1,269 @@
1
+ # coding: utf-8
2
+ __author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3
+
4
+
5
+ import warnings
6
+ import numpy as np
7
+ from numba import jit
8
+ import time
9
+
10
+ @jit(nopython=True)
11
+ def bb_intersection_over_union(A, B) -> float:
12
+ xA = max(A[0], B[0])
13
+ yA = max(A[1], B[1])
14
+ xB = min(A[2], B[2])
15
+ yB = min(A[3], B[3])
16
+
17
+ # compute the area of intersection rectangle
18
+ interArea = max(0, xB - xA) * max(0, yB - yA)
19
+
20
+ if interArea == 0:
21
+ return 0.0
22
+
23
+ # compute the area of both the prediction and ground-truth rectangles
24
+ boxAArea = (A[2] - A[0]) * (A[3] - A[1])
25
+ boxBArea = (B[2] - B[0]) * (B[3] - B[1])
26
+
27
+ iou = interArea / float(boxAArea + boxBArea - interArea)
28
+ return iou
29
+
30
+
31
+ def prefilter_boxes(boxes, scores, labels, weights, thr):
32
+ # Create dict with boxes stored by its label
33
+ new_boxes = dict()
34
+
35
+ for t in range(len(boxes)):
36
+
37
+ if len(boxes[t]) != len(scores[t]):
38
+ print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(boxes[t]), len(scores[t])))
39
+ exit()
40
+
41
+ if len(boxes[t]) != len(labels[t]):
42
+ print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(boxes[t]), len(labels[t])))
43
+ exit()
44
+
45
+ for j in range(len(boxes[t])):
46
+ score = scores[t][j]
47
+ if score < thr:
48
+ continue
49
+ label = int(labels[t][j])
50
+ box_part = boxes[t][j]
51
+ x1 = float(box_part[0])
52
+ y1 = float(box_part[1])
53
+ x2 = float(box_part[2])
54
+ y2 = float(box_part[3])
55
+
56
+ # Box data checks
57
+ if x2 < x1:
58
+ warnings.warn('X2 < X1 value in box. Swap them.')
59
+ x1, x2 = x2, x1
60
+ if y2 < y1:
61
+ warnings.warn('Y2 < Y1 value in box. Swap them.')
62
+ y1, y2 = y2, y1
63
+ if x1 < 0:
64
+ warnings.warn('X1 < 0 in box. Set it to 0.')
65
+ x1 = 0
66
+ if x1 > 1:
67
+ warnings.warn('X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
68
+ x1 = 1
69
+ if x2 < 0:
70
+ warnings.warn('X2 < 0 in box. Set it to 0.')
71
+ x2 = 0
72
+ if x2 > 1:
73
+ warnings.warn('X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
74
+ x2 = 1
75
+ if y1 < 0:
76
+ warnings.warn('Y1 < 0 in box. Set it to 0.')
77
+ y1 = 0
78
+ if y1 > 1:
79
+ warnings.warn('Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
80
+ y1 = 1
81
+ if y2 < 0:
82
+ warnings.warn('Y2 < 0 in box. Set it to 0.')
83
+ y2 = 0
84
+ if y2 > 1:
85
+ warnings.warn('Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
86
+ y2 = 1
87
+ if (x2 - x1) * (y2 - y1) == 0.0:
88
+ warnings.warn("Zero area box skipped: {}.".format(box_part))
89
+ continue
90
+
91
+ # [label, score, weight, model index, x1, y1, x2, y2]
92
+ b = [int(label), float(score) * weights[t], weights[t], t, x1, y1, x2, y2]
93
+ if label not in new_boxes:
94
+ new_boxes[label] = []
95
+ new_boxes[label].append(b)
96
+
97
+ # Sort each list in dict by score and transform it to numpy array
98
+ for k in new_boxes:
99
+ current_boxes = np.array(new_boxes[k])
100
+ new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
101
+
102
+ return new_boxes
103
+
104
+
105
+ def get_weighted_box(boxes, conf_type='avg'):
106
+ """
107
+ Create weighted box for set of boxes
108
+ :param boxes: set of boxes to fuse
109
+ :param conf_type: type of confidence one of 'avg' or 'max'
110
+ :return: weighted box (label, score, weight, x1, y1, x2, y2)
111
+ """
112
+
113
+ box = np.zeros(8, dtype=np.float32)
114
+ conf = 0
115
+ conf_list = []
116
+ w = 0
117
+ for b in boxes:
118
+ box[4:] += (b[1] * b[4:])
119
+ conf += b[1]
120
+ conf_list.append(b[1])
121
+ w += b[2]
122
+ box[0] = boxes[0][0]
123
+ if conf_type == 'avg':
124
+ box[1] = conf / len(boxes)
125
+ elif conf_type == 'max':
126
+ box[1] = np.array(conf_list).max()
127
+ elif conf_type in ['box_and_model_avg', 'absent_model_aware_avg']:
128
+ box[1] = conf / len(boxes)
129
+ box[2] = w
130
+ box[3] = -1 # model index field is retained for consistency but is not used.
131
+ box[4:] /= conf
132
+ return box
133
+
134
+
135
+ def find_matching_box(boxes_list, new_box, match_iou):
136
+ best_iou = match_iou
137
+ best_index = -1
138
+ for i in range(len(boxes_list)):
139
+ box = boxes_list[i]
140
+ if box[0] != new_box[0]:
141
+ continue
142
+ iou = bb_intersection_over_union(box[4:], new_box[4:])
143
+ if iou > best_iou:
144
+ best_index = i
145
+ best_iou = iou
146
+
147
+ return best_index, best_iou
148
+
149
+
150
+ def find_matching_box_quickly(boxes_list, new_box, match_iou):
151
+ """ Reimplementation of find_matching_box with numpy instead of loops. Gives significant speed up for larger arrays
152
+ (~100x). This was previously the bottleneck since the function is called for every entry in the array.
153
+ """
154
+ def bb_iou_array(boxes, new_box):
155
+ # bounding-box intersection over union
156
+ xA = np.maximum(boxes[:, 0], new_box[0])
157
+ yA = np.maximum(boxes[:, 1], new_box[1])
158
+ xB = np.minimum(boxes[:, 2], new_box[2])
159
+ yB = np.minimum(boxes[:, 3], new_box[3])
160
+
161
+ interArea = np.maximum(xB - xA, 0) * np.maximum(yB - yA, 0)
162
+
163
+ # compute the area of both the prediction and ground-truth rectangles
164
+ boxAArea = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
165
+ boxBArea = (new_box[2] - new_box[0]) * (new_box[3] - new_box[1])
166
+
167
+ iou = interArea / (boxAArea + boxBArea - interArea)
168
+
169
+ return iou
170
+
171
+ if boxes_list.shape[0] == 0:
172
+ return -1, match_iou
173
+
174
+ # boxes = np.array(boxes_list)
175
+ boxes = boxes_list
176
+
177
+ ious = bb_iou_array(boxes[:, 4:], new_box[4:])
178
+
179
+ ious[boxes[:, 0] != new_box[0]] = -1
180
+
181
+ best_idx = np.argmax(ious)
182
+ best_iou = ious[best_idx]
183
+
184
+ if best_iou <= match_iou:
185
+ best_iou = match_iou
186
+ best_idx = -1
187
+
188
+ return best_idx, best_iou
189
+
190
+
191
+ def weighted_boxes_fusion(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
192
+ '''
193
+ :param boxes_list: list of boxes predictions from each model, each box is 4 numbers.
194
+ It has 3 dimensions (models_number, model_preds, 4)
195
+ Order of boxes: x1, y1, x2, y2. We expect float normalized coordinates [0; 1]
196
+ :param scores_list: list of scores for each model
197
+ :param labels_list: list of labels for each model
198
+ :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
199
+ :param iou_thr: IoU value for boxes to be a match
200
+ :param skip_box_thr: exclude boxes with score lower than this variable
201
+ :param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value, 'box_and_model_avg': box and model wise hybrid weighted average, 'absent_model_aware_avg': weighted average that takes into account the absent model.
202
+ :param allows_overflow: false if we want confidence score not exceed 1.0
203
+
204
+ :return: boxes: boxes coordinates (Order of boxes: x1, y1, x2, y2).
205
+ :return: scores: confidence scores
206
+ :return: labels: boxes labels
207
+ '''
208
+
209
+ if weights is None:
210
+ weights = np.ones(len(boxes_list))
211
+ if len(weights) != len(boxes_list):
212
+ print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
213
+ weights = np.ones(len(boxes_list))
214
+ weights = np.array(weights)
215
+
216
+ if conf_type not in ['avg', 'max', 'box_and_model_avg', 'absent_model_aware_avg']:
217
+ print('Unknown conf_type: {}. Must be "avg", "max" or "box_and_model_avg", or "absent_model_aware_avg"'.format(conf_type))
218
+ exit()
219
+
220
+ filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
221
+ if len(filtered_boxes) == 0:
222
+ return np.zeros((0, 4)), np.zeros((0,)), np.zeros((0,))
223
+
224
+ overall_boxes = []
225
+ for label in filtered_boxes:
226
+ boxes = filtered_boxes[label]
227
+ new_boxes = []
228
+ weighted_boxes = np.empty((0,8))
229
+ # Clusterize boxes
230
+ for j in range(0, len(boxes)):
231
+ index, best_iou = find_matching_box_quickly(weighted_boxes, boxes[j], iou_thr)
232
+
233
+ if index != -1:
234
+ new_boxes[index].append(boxes[j])
235
+ weighted_boxes[index] = get_weighted_box(new_boxes[index], conf_type)
236
+ else:
237
+ new_boxes.append([boxes[j].copy()])
238
+ weighted_boxes = np.vstack((weighted_boxes, boxes[j].copy()))
239
+ # Rescale confidence based on number of models and boxes
240
+ for i in range(len(new_boxes)):
241
+ clustered_boxes = np.array(new_boxes[i])
242
+ if conf_type == 'box_and_model_avg':
243
+ # weighted average for boxes
244
+ weighted_boxes[i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / weighted_boxes[i, 2]
245
+ # identify unique model index by model index column
246
+ _, idx = np.unique(clustered_boxes[:, 3], return_index=True)
247
+ # rescale by unique model weights
248
+ weighted_boxes[i, 1] = weighted_boxes[i, 1] * clustered_boxes[idx, 2].sum() / weights.sum()
249
+ elif conf_type == 'absent_model_aware_avg':
250
+ # get unique model index in the cluster
251
+ models = np.unique(clustered_boxes[:, 3]).astype(int)
252
+ # create a mask to get unused model weights
253
+ mask = np.ones(len(weights), dtype=bool)
254
+ mask[models] = False
255
+ # absent model aware weighted average
256
+ weighted_boxes[i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / (weighted_boxes[i, 2] + weights[mask].sum())
257
+ elif conf_type == 'max':
258
+ weighted_boxes[i, 1] = weighted_boxes[i, 1] / weights.max()
259
+ elif not allows_overflow:
260
+ weighted_boxes[i, 1] = weighted_boxes[i, 1] * min(len(weights), len(clustered_boxes)) / weights.sum()
261
+ else:
262
+ weighted_boxes[i, 1] = weighted_boxes[i, 1] * len(clustered_boxes) / weights.sum()
263
+ overall_boxes.append(weighted_boxes)
264
+ overall_boxes = np.concatenate(overall_boxes, axis=0)
265
+ overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
266
+ boxes = overall_boxes[:, 4:]
267
+ scores = overall_boxes[:, 1]
268
+ labels = overall_boxes[:, 0]
269
+ return boxes, scores, labels
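A matching sketch for weighted_boxes_fusion (not part of this commit). merge_predictions.py in this repo divides pixel coordinates by a fixed FACTOR before calling it; the values below are illustrative only.

from ensemble_boxes import weighted_boxes_fusion

boxes_list = [
    [[0.10, 0.20, 0.40, 0.60], [0.50, 0.50, 0.70, 0.90]],  # model 1
    [[0.12, 0.21, 0.39, 0.58]],                            # model 2
]
scores_list = [[0.9, 0.4], [0.7]]
labels_list = [[0, 0], [0]]

boxes, scores, labels = weighted_boxes_fusion(
    boxes_list, scores_list, labels_list,
    weights=[1, 1], iou_thr=0.55, skip_box_thr=0.0, conf_type='avg')
# The two overlapping boxes fuse into one; the unmatched second box of model 1
# keeps a halved score because only one of the two models predicted it.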
DenseMammogram/ensemble_boxes/ensemble_boxes_wbf_3d.py ADDED
@@ -0,0 +1,222 @@
1
+ # coding: utf-8
2
+ __author__ = 'ZFTurbo: https://kaggle.com/zfturbo'
3
+
4
+
5
+ import warnings
6
+ import numpy as np
7
+ from numba import jit
8
+
9
+
10
+ @jit(nopython=True)
11
+ def bb_intersection_over_union_3d(A, B) -> float:
12
+ xA = max(A[0], B[0])
13
+ yA = max(A[1], B[1])
14
+ zA = max(A[2], B[2])
15
+ xB = min(A[3], B[3])
16
+ yB = min(A[4], B[4])
17
+ zB = min(A[5], B[5])
18
+
19
+ interVol = max(0, xB - xA) * max(0, yB - yA) * max(0, zB - zA)
20
+ if interVol == 0:
21
+ return 0.0
22
+
23
+ # compute the volume of both the prediction and ground-truth rectangular boxes
24
+ boxAVol = (A[3] - A[0]) * (A[4] - A[1]) * (A[5] - A[2])
25
+ boxBVol = (B[3] - B[0]) * (B[4] - B[1]) * (B[5] - B[2])
26
+
27
+ iou = interVol / float(boxAVol + boxBVol - interVol)
28
+ return iou
29
+
30
+
31
+ def prefilter_boxes(boxes, scores, labels, weights, thr):
32
+ # Create dict with boxes stored by its label
33
+ new_boxes = dict()
34
+
35
+ for t in range(len(boxes)):
36
+
37
+ if len(boxes[t]) != len(scores[t]):
38
+ print('Error. Length of boxes arrays not equal to length of scores array: {} != {}'.format(len(boxes[t]), len(scores[t])))
39
+ exit()
40
+
41
+ if len(boxes[t]) != len(labels[t]):
42
+ print('Error. Length of boxes arrays not equal to length of labels array: {} != {}'.format(len(boxes[t]), len(labels[t])))
43
+ exit()
44
+
45
+ for j in range(len(boxes[t])):
46
+ score = scores[t][j]
47
+ if score < thr:
48
+ continue
49
+ label = int(labels[t][j])
50
+ box_part = boxes[t][j]
51
+ x1 = float(box_part[0])
52
+ y1 = float(box_part[1])
53
+ z1 = float(box_part[2])
54
+ x2 = float(box_part[3])
55
+ y2 = float(box_part[4])
56
+ z2 = float(box_part[5])
57
+
58
+ # Box data checks
59
+ if x2 < x1:
60
+ warnings.warn('X2 < X1 value in box. Swap them.')
61
+ x1, x2 = x2, x1
62
+ if y2 < y1:
63
+ warnings.warn('Y2 < Y1 value in box. Swap them.')
64
+ y1, y2 = y2, y1
65
+ if z2 < z1:
66
+ warnings.warn('Z2 < Z1 value in box. Swap them.')
67
+ z1, z2 = z2, z1
68
+ if x1 < 0:
69
+ warnings.warn('X1 < 0 in box. Set it to 0.')
70
+ x1 = 0
71
+ if x1 > 1:
72
+ warnings.warn('X1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
73
+ x1 = 1
74
+ if x2 < 0:
75
+ warnings.warn('X2 < 0 in box. Set it to 0.')
76
+ x2 = 0
77
+ if x2 > 1:
78
+ warnings.warn('X2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
79
+ x2 = 1
80
+ if y1 < 0:
81
+ warnings.warn('Y1 < 0 in box. Set it to 0.')
82
+ y1 = 0
83
+ if y1 > 1:
84
+ warnings.warn('Y1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
85
+ y1 = 1
86
+ if y2 < 0:
87
+ warnings.warn('Y2 < 0 in box. Set it to 0.')
88
+ y2 = 0
89
+ if y2 > 1:
90
+ warnings.warn('Y2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
91
+ y2 = 1
92
+ if z1 < 0:
93
+ warnings.warn('Z1 < 0 in box. Set it to 0.')
94
+ z1 = 0
95
+ if z1 > 1:
96
+ warnings.warn('Z1 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
97
+ z1 = 1
98
+ if z2 < 0:
99
+ warnings.warn('Z2 < 0 in box. Set it to 0.')
100
+ z2 = 0
101
+ if z2 > 1:
102
+ warnings.warn('Z2 > 1 in box. Set it to 1. Check that you normalize boxes in [0, 1] range.')
103
+ z2 = 1
104
+ if (x2 - x1) * (y2 - y1) * (z2 - z1) == 0.0:
105
+ warnings.warn("Zero volume box skipped: {}.".format(box_part))
106
+ continue
107
+
108
+ b = [int(label), float(score) * weights[t], x1, y1, z1, x2, y2, z2]
109
+ if label not in new_boxes:
110
+ new_boxes[label] = []
111
+ new_boxes[label].append(b)
112
+
113
+ # Sort each list in dict by score and transform it to numpy array
114
+ for k in new_boxes:
115
+ current_boxes = np.array(new_boxes[k])
116
+ new_boxes[k] = current_boxes[current_boxes[:, 1].argsort()[::-1]]
117
+
118
+ return new_boxes
119
+
120
+
121
+ def get_weighted_box(boxes, conf_type='avg'):
122
+ """
123
+ Create weighted box for set of boxes
124
+ :param boxes: set of boxes to fuse
125
+ :param conf_type: type of confidence one of 'avg' or 'max'
126
+ :return: weighted box
127
+ """
128
+
129
+ box = np.zeros(8, dtype=np.float32)
130
+ conf = 0
131
+ conf_list = []
132
+ for b in boxes:
133
+ box[2:] += (b[1] * b[2:])
134
+ conf += b[1]
135
+ conf_list.append(b[1])
136
+ box[0] = boxes[0][0]
137
+ if conf_type == 'avg':
138
+ box[1] = conf / len(boxes)
139
+ elif conf_type == 'max':
140
+ box[1] = np.array(conf_list).max()
141
+ box[2:] /= conf
142
+ return box
143
+
144
+
145
+ def find_matching_box(boxes_list, new_box, match_iou):
146
+ best_iou = match_iou
147
+ best_index = -1
148
+ for i in range(len(boxes_list)):
149
+ box = boxes_list[i]
150
+ if box[0] != new_box[0]:
151
+ continue
152
+ iou = bb_intersection_over_union_3d(box[2:], new_box[2:])
153
+ if iou > best_iou:
154
+ best_index = i
155
+ best_iou = iou
156
+
157
+ return best_index, best_iou
158
+
159
+
160
+ def weighted_boxes_fusion_3d(boxes_list, scores_list, labels_list, weights=None, iou_thr=0.55, skip_box_thr=0.0, conf_type='avg', allows_overflow=False):
161
+ '''
162
+ :param boxes_list: list of boxes predictions from each model, each box is 6 numbers.
163
+ It has 3 dimensions (models_number, model_preds, 6)
164
+ Order of boxes: x1, y1, z1, x2, y2 z2. We expect float normalized coordinates [0; 1]
165
+ :param scores_list: list of scores for each model
166
+ :param labels_list: list of labels for each model
167
+ :param weights: list of weights for each model. Default: None, which means weight == 1 for each model
168
+ :param iou_thr: IoU value for boxes to be a match
169
+ :param skip_box_thr: exclude boxes with score lower than this variable
170
+ :param conf_type: how to calculate confidence in weighted boxes. 'avg': average value, 'max': maximum value
171
+ :param allows_overflow: false if we want confidence score not exceed 1.0
172
+
173
+ :return: boxes: boxes coordinates (Order of boxes: x1, y1, z1, x2, y2, z2).
174
+ :return: scores: confidence scores
175
+ :return: labels: boxes labels
176
+ '''
177
+
178
+ if weights is None:
179
+ weights = np.ones(len(boxes_list))
180
+ if len(weights) != len(boxes_list):
181
+ print('Warning: incorrect number of weights {}. Must be: {}. Set weights equal to 1.'.format(len(weights), len(boxes_list)))
182
+ weights = np.ones(len(boxes_list))
183
+ weights = np.array(weights)
184
+
185
+ if conf_type not in ['avg', 'max']:
186
+ print('Error. Unknown conf_type: {}. Must be "avg" or "max". Use "avg"'.format(conf_type))
187
+ conf_type = 'avg'
188
+
189
+ filtered_boxes = prefilter_boxes(boxes_list, scores_list, labels_list, weights, skip_box_thr)
190
+ if len(filtered_boxes) == 0:
191
+ return np.zeros((0, 6)), np.zeros((0,)), np.zeros((0,))
192
+
193
+ overall_boxes = []
194
+ for label in filtered_boxes:
195
+ boxes = filtered_boxes[label]
196
+ new_boxes = []
197
+ weighted_boxes = []
198
+
199
+ # Clusterize boxes
200
+ for j in range(0, len(boxes)):
201
+ index, best_iou = find_matching_box(weighted_boxes, boxes[j], iou_thr)
202
+ if index != -1:
203
+ new_boxes[index].append(boxes[j])
204
+ weighted_boxes[index] = get_weighted_box(new_boxes[index], conf_type)
205
+ else:
206
+ new_boxes.append([boxes[j].copy()])
207
+ weighted_boxes.append(boxes[j].copy())
208
+
209
+ # Rescale confidence based on number of models and boxes
210
+ for i in range(len(new_boxes)):
211
+ if not allows_overflow:
212
+ weighted_boxes[i][1] = weighted_boxes[i][1] * min(weights.sum(), len(new_boxes[i])) / weights.sum()
213
+ else:
214
+ weighted_boxes[i][1] = weighted_boxes[i][1] * len(new_boxes[i]) / weights.sum()
215
+ overall_boxes.append(np.array(weighted_boxes))
216
+
217
+ overall_boxes = np.concatenate(overall_boxes, axis=0)
218
+ overall_boxes = overall_boxes[overall_boxes[:, 1].argsort()[::-1]]
219
+ boxes = overall_boxes[:, 2:]
220
+ scores = overall_boxes[:, 1]
221
+ labels = overall_boxes[:, 0]
222
+ return boxes, scores, labels
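The 3D variant behaves the same way but expects six normalized coordinates per box (x1, y1, z1, x2, y2, z2). A tiny sketch, not part of the commit, with made-up volumes:

from ensemble_boxes.ensemble_boxes_wbf_3d import weighted_boxes_fusion_3d

boxes_list = [[[0.10, 0.10, 0.10, 0.40, 0.40, 0.40]],   # model 1
              [[0.12, 0.10, 0.11, 0.38, 0.42, 0.40]]]   # model 2
scores_list = [[0.8], [0.6]]
labels_list = [[0], [0]]
boxes, scores, labels = weighted_boxes_fusion_3d(
    boxes_list, scores_list, labels_list, iou_thr=0.4)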
DenseMammogram/experimenter.py ADDED
@@ -0,0 +1,213 @@
1
+ # The Experimenter class is responsible for the following:
2
+ # 1. Configuration - Done
3
+ # 2. Logging using the AdvancedLogger class - Almost Done
4
+ # 3. Model Handling, including loading and saving models - Done(Upgrades Left)
5
+ # 4. Running different experiment variants in parallel or sequentially
6
+ # 5. Combining frcnn training followed by bilateral training and final froc calculation - Done
7
+ # 6. Version Control
8
+
9
+ from advanced_config import AdvancedConfig
10
+ from advanced_logger import AdvancedLogger, LogPriority
11
+ import os
12
+ from os.path import join
13
+ from plot_froc import plot_froc
14
+ from train_frcnn import main as TRAIN_FRCNN
15
+ from train_bilateral import main as TRAIN_BILATERAL
16
+ import torch
17
+ from model_utils import generate_predictions, generate_predictions_bilateral
18
+ import argparse
19
+ from dataloaders import get_dict
20
+ from utils import create_backup
21
+ from torch.utils.tensorboard import SummaryWriter
22
+
23
+ class Experimenter:
24
+
25
+ def __init__(self, cfg_file, BASE_DIR = 'experiments'):
26
+ self.cfg_file = cfg_file
27
+
28
+ self.con = AdvancedConfig(cfg_file)
29
+ self.config = self.con.config
30
+ self.exp_dir = join(BASE_DIR,self.config['EXP_NAME'])
31
+ os.makedirs(self.exp_dir, exist_ok=True)
32
+ self.con.save(join(self.exp_dir,'config.cfg'))
33
+
34
+ self.logger = AdvancedLogger(self.exp_dir)
35
+ self.logger.log('Experiment:',self.config['EXP_NAME'],priority = LogPriority.STATS)
36
+ self.logger.log('Experiment Description:', self.config['EXP_DESC'], priority = LogPriority.STATS)
37
+ self.logger.log('Config File:',self.cfg_file, priority = LogPriority.STATS)
38
+ self.logger.log('Experiment started', priority = LogPriority.LOW)
39
+ self.losses = dict()
40
+ self.frocs = dict()
41
+
42
+ self.writer = SummaryWriter(join(self.exp_dir,'tensor_logs'))
43
+
44
+ create_backup(backup_dir=join(self.exp_dir,'scripts'))
45
+
46
+ def log(self, *args, **kwargs):
47
+ self.logger.log(*args, **kwargs)
48
+
49
+
50
+ def init_losses(self,mode):
51
+ if mode == 'FRCNN' or mode == 'FRCNN_BILATERAL':
52
+ self.losses['frcnn_loss'] = []
53
+ self.frocs['frcnn_froc'] = []
54
+ elif mode == 'BILATERAL' or mode == 'FRCNN_BILATERAL':
55
+ self.losses['bilateral_loss'] = []
56
+ self.frocs['bilateral_froc'] = []
57
+
58
+ def start_epoch(self):
59
+ self.curr_epoch += 1
60
+ self.logger.log('Epoch:',self.curr_epoch, priority = LogPriority.MEDIUM)
61
+
62
+ def end_epoch(self, loss, model = None, device = None):
63
+ if self.curr_mode == 'FRCNN':
64
+ self.losses['frcnn_loss'].append(loss)
65
+ self.best_loss = min(self.losses['frcnn_loss'])
66
+ if self.config['EVAL_METHOD'] == 'FROC':
67
+ exp_name = self.config['EXP_NAME']
68
+ _, val_path, _ = self.init_paths()
69
+ generate_predictions(model,device,val_path,f'preds_frcnn_{exp_name}')
70
+ from froc_by_pranjal import get_froc_points
71
+ senses, _ = get_froc_points(f'preds_frcnn_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT']), fps_req = [0.2])
72
+ self.frocs['frcnn_froc'].append(senses[0])
73
+ self.best_froc = max(self.frocs['frcnn_froc'])
74
+ self.logger.log(f'Val FROC: {senses[0]}', LogPriority.MEDIUM)
75
+ self.logger.log(f'Best FROC: {self.best_froc}')
76
+ elif self.curr_mode == 'BILATERAL':
77
+ self.losses['bilateral_loss'].append(loss)
78
+ self.best_loss = min(self.losses['bilateral_loss'])
79
+ if self.config['EVAL_METHOD'] == 'FROC':
80
+ exp_name = self.config['EXP_NAME']
81
+ _, val_path, _ = self.init_paths()
82
+ data_dir = self.config['DATA_DIR']
83
+ print('Generating')
84
+ generate_predictions_bilateral(model,device,val_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),preds_folder = f'preds_bilateral_{exp_name}')
85
+ print('Generation Done')
86
+ from froc_by_pranjal import get_froc_points
87
+ senses, _ = get_froc_points(f'preds_bilateral_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT']), fps_req = [0.1])
88
+ print('Reading Sens from',f'preds_bilateral_{exp_name}', join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT']),)
89
+
90
+ self.frocs['bilateral_froc'].append(senses[0])
91
+ self.best_froc = max(self.frocs['bilateral_froc'])
92
+ self.logger.log(f'Val FROC: {senses[0]}', priority = LogPriority.MEDIUM)
93
+ self.logger.log(f'Best FROC: {self.best_froc}')
94
+
95
+ self.writer.add_scalar(f"{self.curr_mode}/Loss/Valid", loss, self.curr_epoch)
96
+
97
+
98
+
99
+ def save_model(self, model):
100
+ if self.curr_mode == 'FRCNN':
101
+ self.logger.log('Saving FRCNN Model', priority = LogPriority.LOW)
102
+ model_file = join(self.exp_dir,'frcnn_models',f'frcnn_model.pth')
103
+ if self.config['EVAL_METHOD']:
104
+ SAVE = self.best_froc == self.frocs['frcnn_froc'][-1]
105
+ else:
106
+ SAVE = self.best_loss == self.losses['frcnn_loss'][-1]
107
+ elif self.curr_mode == 'BILATERAL':
108
+ self.logger.log('Saving Bilateral Model', priority = LogPriority.LOW)
109
+ model_file = join(self.exp_dir,'bilateral_models',f'bilateral_model.pth')
110
+ if self.config['EVAL_METHOD'] == 'FROC':
111
+ SAVE = self.best_froc == self.frocs['bilateral_froc'][-1]
112
+ else:
113
+ SAVE = self.best_loss == self.losses['bilateral_loss'][-1]
114
+ os.makedirs(os.path.split(model_file)[0], exist_ok=True)
115
+ if SAVE:
116
+ torch.save(model.state_dict(), model_file)
117
+
118
+ torch.save(model.state_dict(), f'{model_file[:-4]}_{self.curr_epoch}.pth')
119
+
120
+ def init_paths(self,):
121
+ train_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_TRAIN_SPLIT'])
122
+ val_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_VAL_SPLIT'])
123
+ test_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT'])
124
+ return train_path, val_path, test_path
125
+
126
+ def abs_path(self, path):
127
+ return join(self.config['DATA_DIR'], path)
128
+
129
+ # Impure function: updates the model in place with the best checkpoint's state dict
130
+ def generate_predictions(self,model, device):
131
+ self.logger.log('Generating Predictions')
132
+ self.logger.flush()
133
+ exp_name = self.config['EXP_NAME']
134
+ train_path, val_path, test_path = self.init_paths()
135
+
136
+ # Load the best val_loss model's state dicts
137
+ if self.curr_mode == 'FRCNN':
138
+ model_file = join(self.exp_dir,'frcnn_models','frcnn_model.pth')
139
+ elif self.curr_mode == 'BILATERAL':
140
+ model_file = join(self.exp_dir,'bilateral_models','bilateral_model.pth')
141
+ model.load_state_dict(torch.load(model_file))
142
+
143
+ if self.curr_mode == 'FRCNN':
144
+ generate_predictions(model,device,train_path,f'preds_frcnn_{exp_name}')
145
+ generate_predictions(model,device,val_path,f'preds_frcnn_{exp_name}')
146
+ generate_predictions(model,device,test_path,f'preds_frcnn_{exp_name}')
147
+ elif self.curr_mode == 'BILATERAL':
148
+ data_dir = self.config['DATA_DIR']
149
+ generate_predictions_bilateral(model,device,train_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),'aiims',f'preds_bilateral_{exp_name}')
150
+ generate_predictions_bilateral(model,device,val_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),'aiims',f'preds_bilateral_{exp_name}')
151
+ generate_predictions_bilateral(model,device,test_path,get_dict(data_dir,self.abs_path(self.config['AIIMS_CORRS_LIST'])),'aiims',f'preds_bilateral_{exp_name}')
152
+ test_path = join(self.config['DATA_DIR'], self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT'])
153
+
154
+ def run_experiment(self):
155
+
156
+ # First Determine the mode of running the experiment
157
+ mode = self.config['MODE']
158
+ self.init_losses(mode)
159
+ self.curr_mode = 'FRCNN'
160
+ self.curr_epoch = -1
161
+ self.best_loss = 999999
162
+ self.best_froc = 0
163
+ if mode == 'FRCNN':
164
+ TRAIN_FRCNN(self.config['FRCNN'], self)
165
+ elif mode == 'BILATERAL':
166
+ self.curr_mode = 'BILATERAL'
167
+ TRAIN_BILATERAL(self.config['BILATERAL'], self)
168
+ elif mode == 'FRCNN_BILATERAL':
169
+ TRAIN_FRCNN(self.config['FRCNN'], self)
170
+ self.curr_mode = 'BILATERAL'
171
+ self.curr_epoch = -1
172
+ self.best_loss = 999999
173
+ # Note the path to frcnn model must be the same as that dictated by experiment
174
+ self.config['BILATERAL']['FRCNN_MODEL_PATH'] = join(self.exp_dir,'frcnn_models','frcnn_model.pth')
175
+ TRAIN_BILATERAL(self.config['BILATERAL'], self)
176
+
177
+ self.logger.log(f'Best Loss: {self.best_loss}', priority= LogPriority.STATS)
178
+ self.logger.log('Experiment Training and Generation Ended', priority = LogPriority.MEDIUM)
179
+
180
+ # Now evaluate the results
181
+
182
+ frcnn_file = join(self.exp_dir, 'senses_fps_frcnn.txt')
183
+ bilateral_file = join(self.exp_dir, 'senses_fps_bilateral.txt')
184
+ from froc_by_pranjal import get_froc_points
185
+ exp_name = self.config['EXP_NAME']
186
+ if mode == 'FRCNN' or mode == 'FRCNN_BILATERAL':
187
+ senses, fps = get_froc_points(f'preds_frcnn_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT']), save_to = frcnn_file)
188
+ self.logger.log('FRCNN RESULTS', priority = LogPriority.STATS)
189
+ for s,f in zip(senses, fps):
190
+ self.logger.log(f'Sensitivty at {f}: {s}', priority = LogPriority.STATS)
191
+ if mode == 'BILATERAL' or mode == 'FRCNN_BILATERAL':
192
+ senses, fps = get_froc_points(f'preds_bilateral_{exp_name}', root_fol= join(self.config['DATA_DIR'],self.config['AIIMS_DATA'], self.config['AIIMS_TEST_SPLIT']), save_to = bilateral_file)
193
+ self.logger.log('BILATERAL RESULTS', priority = LogPriority.STATS)
194
+ for s,f in zip(senses, fps):
195
+ self.logger.log(f'Sensitivty at {f}: {s}', priority = LogPriority.STATS)
196
+
197
+
198
+ # Now draw the graphs.... If FRCNN and BILATERAL both done, draw them on one graph
199
+ # Else draw single graphs only
200
+ if mode == 'FRCNN':
201
+ plot_froc({frcnn_file : 'FRCNN'}, join(self.exp_dir,'plot.png'), TITLE = 'FRCNN FROC')
202
+ elif mode == 'BILATERAL':
203
+ plot_froc({bilateral_file : 'BILATERAL'}, join(self.exp_dir,'plot.png'), TITLE = 'BILATERAL FROC')
204
+ elif mode == 'FRCNN_BILATERAL':
205
+ plot_froc({frcnn_file : 'FRCNN', bilateral_file : 'BILATERAL'}, join(self.exp_dir,'plot.png'), TITLE = 'FRCNN vs BILATERAL FROC')
206
+ self.logger.flush()
207
+
208
+ if __name__ == '__main__':
209
+ parser = argparse.ArgumentParser()
210
+ parser.add_argument('--cfg_file', type=str, default='configs/AIIMS_C1.cfg')
211
+ args = parser.parse_args()
212
+ exp = Experimenter(args.cfg_file)
213
+ exp.run_experiment()
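A hedged usage sketch (not part of the commit): the class is normally driven from the command line exactly as in the __main__ block above, but it can also be used programmatically. The config path below is the argparse default; the keys it must define (EXP_NAME, EXP_DESC, MODE, DATA_DIR, the AIIMS_* splits, EVAL_METHOD, and the FRCNN/BILATERAL sub-configs) are the ones read in the code above.

from experimenter import Experimenter

# MODE in the config selects 'FRCNN', 'BILATERAL' or 'FRCNN_BILATERAL'.
exp = Experimenter('configs/AIIMS_C1.cfg')
exp.run_experiment()   # trains, generates predictions, writes senses_fps_*.txt and plot.png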
DenseMammogram/froc_by_pranjal.py ADDED
@@ -0,0 +1,236 @@
1
+ import os
2
+ import glob
3
+ import sys
4
+ from os.path import join
5
+
6
+
7
+ '''
8
+ Note: an empty set of boxes is always represented as [] and never [[]]
9
+ '''
10
+
11
+
12
+ def remove_true_positives(gts, preds):
13
+
14
+ def true_positive(gt, pred):
15
+ # If center of pred is inside the gt, it is a true positive
16
+ c_pred = ((pred[0]+pred[2])/2., (pred[1]+pred[3])/2.)
17
+ if (c_pred[0] >= gt[0] and c_pred[0] <= gt[2] and
18
+ c_pred[1] >= gt[1] and c_pred[1] <= gt[3]):
19
+ return True
20
+ return False
21
+
22
+ tps = 0
23
+ fns = 0
24
+
25
+ for gt in gts:
26
+ # First check if any true positive exists
27
+ # All preds that hit this GT are removed from the pool, but they count as a single TP
28
+ add_tp = False
29
+ new_preds = []
30
+ for pred in preds:
31
+ if true_positive(gt, pred):
32
+ add_tp = True
33
+ else:
34
+ new_preds.append(pred)
35
+ preds = new_preds
36
+ if add_tp:
37
+ tps += 1
38
+ else:
39
+ fns += 1
40
+ return preds, tps, fns
41
+
42
+
43
+
44
+ def calc_metric_single(gts, preds, threshold,):
45
+ '''
46
+ Returns fp, tp, tn, fn
47
+ '''
48
+ preds = list(filter(lambda x: x[0] >= threshold, preds))
49
+ preds = [pred[1:] for pred in preds] # Remove the scores
50
+
51
+ if len(gts) == 0:
52
+ return len(preds), 0, 1 if len(preds) == 0 else 0, 0
53
+ preds, tps, fns = remove_true_positives(gts, preds)
54
+ # All remaining predictions have to be false positives
55
+ fps = len(preds)
56
+ return fps, tps, 0, fns
57
+
58
+
59
+ def calc_metrics_at_thresh(im_dict, threshold):
60
+ '''
61
+ Returns fp, tp, tn, fn
62
+ '''
63
+ fps, tps, tns, fns = 0, 0, 0, 0
64
+ for key in im_dict:
65
+ fp,tp,tn,fn = calc_metric_single(im_dict[key]['gt'],
66
+ im_dict[key]['preds'], threshold)
67
+ fps+=fp
68
+ tps+=tp
69
+ tns+=tn
70
+ fns+=fn
71
+
72
+ return fps, tps, tns, fns
73
+
74
+ from joblib import Parallel, delayed
75
+
76
+ def calc_metrics(inp):
77
+ im_dict, tr = inp
78
+ out = dict()
79
+ for t in tr:
80
+ fp, tp, tn, fn = calc_metrics_at_thresh(im_dict, t)
81
+ out[t] = [fp, tp, tn, fn]
82
+ return out
83
+
84
+
85
+ def calc_froc_from_dict(im_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3], save_to = None):
86
+
87
+ num_images = len(im_dict)
88
+
89
+ gap = 0.005
90
+ n = int(1/gap)
91
+ thresholds = [i * gap for i in range(n)]
92
+ fps = [0 for _ in range(n)]
93
+ tps = [0 for _ in range(n)]
94
+ tns = [0 for _ in range(n)]
95
+ fns = [0 for _ in range(n)]
96
+
97
+
98
+ for i,t in enumerate(thresholds):
99
+ fps[i], tps[i], tns[i], fns[i] = calc_metrics_at_thresh(im_dict, t)
100
+
101
+
102
+ # Now calculate the sensitivities
103
+ senses = []
104
+ for t,f in zip(tps, fns):
105
+ try: senses.append(t/(t+f))
106
+ except: senses.append(0.)
107
+
108
+ if save_to is not None:
109
+ f = open(save_to, 'w')
110
+ for fp,s in zip(fps, senses):
111
+ f.write(f'{fp/num_images} {s}\n')
112
+ f.close()
113
+
114
+ senses_req = []
115
+ for fp_req in fps_req:
116
+ for i,f in enumerate(fps):
117
+ if f/num_images < fp_req:
118
+ if fp_req == 0.1:
119
+ print(fps[i], tps[i], tns[i], fns[i])
120
+ prec = tps[i]/(tps[i] + fps[i])
121
+ recall = tps[i]/(tps[i] + fns[i])
122
+ f1 = 2*prec*recall/(prec+recall)
123
+ spec = tns[i]/ (tns[i] + fps[i])
124
+ print(f'Specificity: {spec}')
125
+ print(f'Precision: {prec}')
126
+ print(f'Recall: {recall}')
127
+ print(f'F1: {f1}')
128
+ senses_req.append(senses[i-1])
129
+ break
130
+ return senses_req, fps_req
131
+
132
+
133
+
134
+
135
+ def file_to_bbox(file_name):
136
+ try:
137
+ content = open(file_name, 'r').readlines()
138
+ st = 0
139
+ if len(content) == 0:
140
+ # Empty File Should Return []
141
+ return []
142
+ if content[0].split()[0].isalpha():
143
+ st = 1
144
+ return [[float(x) for x in line.split()[st:]] for line in content]
145
+ except FileNotFoundError:
146
+ print(f'No Corresponding Box Found for file {file_name}, using [] as preds')
147
+ return []
148
+ except Exception as e:
149
+ print('Some Error',e)
150
+ return []
151
+
152
+ def generate_image_dict(preds_folder_name='preds_42',
153
+ root_fol='/home/pranjal/densebreeast_datasets/AIIMS_C1',
154
+ mal_path=None, ben_path=None, gt_path=None,
155
+ mal_img_path = None, ben_img_path = None
156
+ ):
157
+
158
+ mal_path = join(root_fol, mal_path) if mal_path else join(
159
+ root_fol, 'mal', preds_folder_name)
160
+ ben_path = join(root_fol, ben_path) if ben_path else join(
161
+ root_fol, 'ben', preds_folder_name)
162
+ mal_img_path = join(root_fol, mal_img_path) if mal_img_path else join(
163
+ root_fol, 'mal', 'images')
164
+ ben_img_path = join(root_fol, ben_img_path) if ben_img_path else join(
165
+ root_fol, 'ben', 'images')
166
+ gt_path = join(root_fol, gt_path) if gt_path else join(
167
+ root_fol, 'mal', 'gt')
168
+
169
+
170
+ '''
171
+ image_dict structure:
172
+ 'image_name(without txt/png)' : {'gt' : [[...]], 'preds' : [[]]}
173
+ '''
174
+ image_dict = dict()
175
+
176
+ # GT files might differ slightly from the images, so we index GTs based on
177
+ # the images folder instead.
178
+ for file in os.listdir(mal_img_path):
179
+ if not file.endswith('.png'):
180
+ continue
181
+ file = file[:-4] + '.txt'
182
+ file = join(gt_path, file)
183
+ key = os.path.split(file)[-1][:-4]
184
+ image_dict[key] = dict()
185
+ image_dict[key]['gt'] = file_to_bbox(file)
186
+ image_dict[key]['preds'] = []
187
+
188
+ for file in glob.glob(join(mal_path, '*.txt')):
189
+ key = os.path.split(file)[-1][:-4]
190
+ assert key in image_dict
191
+ image_dict[key]['preds'] = file_to_bbox(file)
192
+
193
+ for file in os.listdir(ben_img_path):
194
+ if not file.endswith('.png'):
195
+ continue
196
+
197
+ file = file[:-4] + '.txt'
198
+ file = join(ben_path, file)
199
+ key = os.path.split(file)[-1][:-4]
200
+ if key == 'Calc-Test_P_00353_LEFT_CC' or key == 'Calc-Training_P_00600_LEFT_CC': # Corrupt Files in Dataset
201
+ continue
202
+ if key in image_dict:
203
+ print(key)
204
+ # assert key not in image_dict
205
+ if key in image_dict:
206
+ print(f'Unexpected Error. {key} exists in multiple splits')
207
+ continue
208
+ image_dict[key] = dict()
209
+ image_dict[key]['preds'] = file_to_bbox(file)
210
+ image_dict[key]['gt'] = []
211
+ return image_dict
212
+
213
+
214
+ def pretty_print_fps(senses,fps):
215
+ for s,f in zip(senses,fps):
216
+ print(f'Sensitivty at {f}: {s}')
217
+
218
+ def get_froc_points(preds_image_folder, root_fol, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3], save_to = None):
219
+ im_dict = generate_image_dict(preds_image_folder, root_fol = root_fol)
220
+ # print(im_dict)
221
+ print(len(im_dict))
222
+ senses, fps = calc_froc_from_dict(im_dict, fps_req, save_to = save_to)
223
+ return senses, fps
224
+
225
+ if __name__ == '__main__':
226
+ seed = '42' if len(sys.argv)== 1 else sys.argv[1]
227
+
228
+ root_fol = '../bilateral_new/MammoDatasets/AIIMS_highres_reliable/test_2'
229
+
230
+ if len(sys.argv) <= 2:
231
+ save_to = None
232
+ else:
233
+ save_to = sys.argv[2]
234
+ senses, fps = get_froc_points(f'preds_{seed}',root_fol, save_to = save_to)
235
+
236
+ pretty_print_fps(senses, fps)
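A minimal sketch of the image_dict layout consumed by calc_froc_from_dict (not part of the commit; the case names and boxes are invented). Ground-truth boxes are [x1, y1, x2, y2] in pixels, predictions are [score, x1, y1, x2, y2], and a prediction counts as a true positive when its centre lies inside a ground-truth box.

from froc_by_pranjal import calc_froc_from_dict

im_dict = {
    'case_001': {'gt': [[100, 100, 400, 400]],
                 'preds': [[0.9, 120, 110, 380, 390], [0.3, 600, 600, 700, 700]]},
    'case_002': {'gt': [], 'preds': [[0.2, 50, 50, 80, 90]]},  # benign image, no lesion
}
senses, fps = calc_froc_from_dict(im_dict, fps_req=[0.5, 1.0])
print(list(zip(fps, senses)))  # sensitivity at 0.5 and 1.0 false positives per image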
DenseMammogram/geenerate_aiims.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+ import torch
3
+ from os.path import join
4
+ from model_utils import generate_predictions, generate_predictions_bilateral
5
+ from models import get_FRCNN_model, Bilateral_model
6
+ from froc_by_pranjal import get_froc_points
7
+ from auc_by_pranjal import get_auc_score
8
+
9
+ ####### PARAMETERS TO ADJUST #######
10
+ exp_name = 'BILATERAL'
11
+ OUT_FILE = 'aiims_full_test_results/bil_complete.txt'
12
+ BILATERAL = True
13
+ dataset_path = 'AIIMS_highres_reliable/test_2'
14
+ ####################################
15
+
16
+
17
+
18
+
19
+ if os.path.split(OUT_FILE)[0]:
20
+ os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ frcnn_model = get_FRCNN_model().to(device)
24
+
25
+ if BILATERAL:
26
+ model = Bilateral_model(frcnn_model).to(device)
27
+ MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
28
+ model.load_state_dict(torch.load(MODEL_PATH))
29
+ else:
30
+ model = frcnn_model
31
+ MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
32
+ model.load_state_dict(torch.load(MODEL_PATH))
33
+
34
+
35
+ test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
36
+
37
+
38
+ def get_aiims_dict(test_path, corr_file):
39
+ extract_file = lambda x: x[x.find('test_2/')+7:]
40
+ corr_dict = {extract_file(line.split()[0].replace('"','')):extract_file(line.split()[1].replace('"','')) for line in open(corr_file).readlines()}
41
+ corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
42
+ return corr_dict
43
+
44
+ if BILATERAL:
45
+ pred_dir = f'preds_bilateral_{exp_name}'
46
+ generate_predictions_bilateral(model,device,test_path, get_aiims_dict(test_path, '../bilateral_new/corr_lists/aiims_corr_list_with_val_full_test.txt'),'aiims',pred_dir)
47
+ else:
48
+ pred_dir = f'preds_frcnn_{exp_name}'
49
+ generate_predictions(model, device, test_path, preds_folder = pred_dir)
50
+
51
+
52
+ file = open(OUT_FILE, 'a')
53
+ file.writelines(f'{exp_name} FROC Score:\n')
54
+ senses, fps = get_froc_points(pred_dir, root_fol= test_path, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.0,1.5])
55
+ for s,f in zip(senses, fps):
56
+ print(f'Sensitivty at {f}: {s}')
57
+ file.writelines(f'Sensitivity at {f}: {s}\n')
58
+ file.close()
59
+
60
+ print('AUC Score:',get_auc_score(pred_dir, test_path, retAcc = True, acc_thresh = 1.))
61
+
DenseMammogram/geenerate_ddsm_preds.py ADDED
@@ -0,0 +1,61 @@
1
+ import os
2
+ import torch
3
+ from os.path import join
4
+ from model_utils import generate_predictions, generate_predictions_bilateral
5
+ from models import get_FRCNN_model, Bilateral_model
6
+ from froc_by_pranjal import get_froc_points
7
+ from auc_by_pranjal import get_auc_score
8
+
9
+ ####### PARAMETERS TO ADJUST #######
10
+ exp_name = 'frcnn_16'
11
+ OUT_FILE = 'ddsm_results/ddsm_dset.txt'
12
+ BILATERAL = False
13
+ dataset_path = 'ddsm_data_no_proc_2100_nocrop/val'
14
+ ####################################
15
+
16
+
17
+
18
+
19
+ if os.path.split(OUT_FILE)[0]:
20
+ os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ frcnn_model = get_FRCNN_model().to(device)
24
+
25
+ if BILATERAL:
26
+ model = Bilateral_model(frcnn_model).to(device)
27
+ MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
28
+ model.load_state_dict(torch.load(MODEL_PATH))
29
+ else:
30
+ model = frcnn_model
31
+ MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
32
+ model.load_state_dict(torch.load(MODEL_PATH))
33
+
34
+
35
+ test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
36
+
37
+
38
+ def get_ddsm_dict(test_path, corr_file):
39
+ extract_file = lambda x: x[x.find('val/')+4:]
40
+ corr_dict = {extract_file(line.split()[0].replace('"','')):extract_file(line.split()[1].replace('"','')) for line in open(corr_file).readlines()}
41
+ corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
42
+ return corr_dict
43
+
44
+ if BILATERAL:
45
+ pred_dir = f'preds_bilateral_{exp_name}'
46
+ generate_predictions_bilateral(model,device,test_path, get_ddsm_dict(test_path, '../bilateral_new/corr_lists/ddsm_corr_list_with_val.txt'),'ddsm',pred_dir)
47
+ else:
48
+ pred_dir = f'preds_frcnn_{exp_name}'
49
+ generate_predictions(model, device, test_path, preds_folder = pred_dir)
50
+
51
+
52
+ file = open(OUT_FILE, 'a')
53
+ file.writelines(f'{exp_name} FROC Score:\n')
54
+ senses, fps = get_froc_points(pred_dir, root_fol= test_path, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.0,1.5])
55
+ for s,f in zip(senses, fps):
56
+ print(f'Sensitivty at {f}: {s}')
57
+ file.writelines(f'Sensitivity at {f}: {s}\n')
58
+ file.close()
59
+
60
+ print('AUC Score:',get_auc_score(pred_dir, test_path, retAcc = True, acc_thresh = 1.))
61
+
DenseMammogram/geenerate_inbreast_preds.py ADDED
@@ -0,0 +1,57 @@
1
+ import os
2
+ import torch
3
+ from os.path import join
4
+ from model_utils import generate_predictions, generate_predictions_bilateral
5
+ from models import get_FRCNN_model, Bilateral_model
6
+ from froc_by_pranjal import get_froc_points
7
+
8
+ ####### PARAMETERS TO ADJUST #######
9
+ exp_name = 'AIIMS_C3'
10
+ OUT_FILE = 'ib_results/c3_frcnn.txt'
11
+ BILATERAL = False
12
+ dataset_path = 'INBREAST_C3/test'
13
+ ####################################
14
+
15
+
16
+
17
+
18
+ if os.path.split(OUT_FILE)[0]:
19
+ os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
20
+
21
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
22
+ frcnn_model = get_FRCNN_model().to(device)
23
+
24
+ if BILATERAL:
25
+ model = Bilateral_model(frcnn_model).to(device)
26
+ MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
27
+ model.load_state_dict(torch.load(MODEL_PATH))
28
+ else:
29
+ model = frcnn_model
30
+ MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
31
+ model.load_state_dict(torch.load(MODEL_PATH))
32
+
33
+
34
+ test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
35
+
36
+
37
+ def get_inbreast_dict(test_path, corr_file):
38
+ extract_file = lambda x: x[x.find('test/')+5:]
39
+ corr_dict = {extract_file(line.split()[0]):extract_file(line.split()[1]) for line in open(corr_file).readlines()}
40
+ corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
41
+ return corr_dict
42
+
43
+ if BILATERAL:
44
+ pred_dir = f'preds_bilateral_{exp_name}'
45
+ generate_predictions_bilateral(model,device,test_path, get_inbreast_dict(test_path, '../bilateral_new/corr_lists/Inbreast_final_correspondence_list.txt'),'inbreast',pred_dir)
46
+ else:
47
+ pred_dir = f'preds_frcnn_{exp_name}'
48
+ generate_predictions(model, device, test_path, preds_folder = pred_dir)
49
+
50
+
51
+ file = open(OUT_FILE, 'a')
52
+ file.writelines(f'{exp_name} FROC Score:\n')
53
+ senses, fps = get_froc_points(pred_dir, root_fol= test_path)
54
+ for s,f in zip(senses, fps):
55
+ file.writelines(f'Sensitivity at {f}: {s}\n')
56
+ file.close()
57
+
DenseMammogram/geenerate_irch.py ADDED
@@ -0,0 +1,62 @@
1
+ import os
2
+ import torch
3
+ from os.path import join
4
+ from model_utils import generate_predictions, generate_predictions_bilateral
5
+ from models import get_FRCNN_model, Bilateral_model
6
+ from froc_by_pranjal import get_froc_points
7
+ from auc_by_pranjal import get_auc_score
8
+
9
+ ####### PARAMETERS TO ADJUST #######
10
+ exp_name = 'BILATERAL'
11
+ OUT_FILE = 'irchvalres/bil_final.txt'
12
+ BILATERAL = True
13
+ dataset_path = 'IRCHVal'
14
+ ####################################
15
+
16
+
17
+
18
+
19
+ if os.path.split(OUT_FILE)[0]:
20
+ os.makedirs(os.path.split(OUT_FILE)[0], exist_ok=True)
21
+
22
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
23
+ frcnn_model = get_FRCNN_model().to(device)
24
+
25
+ if BILATERAL:
26
+ model = Bilateral_model(frcnn_model).to(device)
27
+ MODEL_PATH = f'experiments/{exp_name}/bilateral_models/bilateral_model.pth'
28
+ model.load_state_dict(torch.load(MODEL_PATH))
29
+ else:
30
+ model = frcnn_model
31
+ MODEL_PATH = f'experiments/{exp_name}/frcnn_models/frcnn_model.pth'
32
+ model.load_state_dict(torch.load(MODEL_PATH))
33
+
34
+
35
+ test_path = join('../bilateral_new', 'MammoDatasets',dataset_path)
36
+
37
+
38
+ def get_aiims_dict(test_path, corr_file):
39
+ extract_file = lambda x: x
40
+ corr_dict = {extract_file(line.split('" "')[0].strip().replace('"','')):extract_file(line.split('" "')[1].strip().replace('"','')) for line in open(corr_file).readlines()}
41
+ corr_dict = {join(test_path,k):join(test_path,v) for k,v in corr_dict.items()}
42
+ print(list(corr_dict.keys())[:20])
43
+ return corr_dict
44
+
45
+ if BILATERAL:
46
+ pred_dir = f'preds_bilateral_{exp_name}'
47
+ generate_predictions_bilateral(model,device,test_path, get_aiims_dict(test_path, '../bilateral_new/corr_lists/irch_val.txt'),'irch',pred_dir)
48
+ else:
49
+ pred_dir = f'preds_frcnn_{exp_name}'
50
+ generate_predictions(model, device, test_path, preds_folder = pred_dir)
51
+
52
+
53
+ file = open(OUT_FILE, 'a')
54
+ file.writelines(f'{exp_name} FROC Score:\n')
55
+ senses, fps = get_froc_points(pred_dir, root_fol= test_path, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.0,1.5])
56
+ for s,f in zip(senses, fps):
57
+ print(f'Sensitivty at {f}: {s}')
58
+ file.writelines(f'Sensitivty at {f}: {s}\n')
59
+ file.close()
60
+
61
+ print('AUC Score:',get_auc_score(pred_dir, test_path, retAcc = True, acc_thresh = 1.))
62
+
DenseMammogram/merge_predictions.py ADDED
1
+ import os
2
+ import glob
3
+ import torch
4
+ from os.path import join
5
+ import numpy as np
6
+ from froc_by_pranjal import file_to_bbox, calc_froc_from_dict, pretty_print_fps
7
+ import sys
8
+ from ensemble_boxes import *
9
+ import json
10
+ import pickle
11
+
12
+
13
+
14
+ get_file_id = lambda x: x.split('_')[1]
15
+ get_acr_cat = lambda x: '0' if x not in acr_cat else acr_cat[x]
16
+ cat_to_idx = {'a':1,'b':2,'c':3,'d':4}
17
+
18
+
19
+ def get_image_dict(dataset_paths, labels = ['mal','ben'], allowed = [], USE_ACR = False, acr_cat = None, mp_dict = None):
20
+ image_dict = dict()
21
+ if allowed == []:
22
+ allowed = [i for i in range(len(dataset_paths))]
23
+ for label in labels:
24
+ images = list(set.intersection(*map(set, [os.listdir(dset.format(label)) for dset in dataset_paths])))
25
+ for image in images:
26
+ if USE_ACR:
27
+ acr = get_acr_cat(get_file_id(image))
28
+ # print(acr, image)
29
+ key = image[:-4]
30
+ gts = []
31
+ preds = []
32
+ for i,dset in enumerate(dataset_paths):
33
+ if i not in allowed:
34
+ continue
35
+ if USE_ACR:
36
+ if dset.find('AIIMS_C')!=-1:
37
+ if acr == '0': continue
38
+ if dset.find(f'AIIMS_C{cat_to_idx[acr]}') == -1:
39
+ continue
40
+ # Now choose dset to be the acr category one
41
+ dset = dset.replace('/test',f'/test_{acr}')
42
+ # print('ds',dset)
43
+ pred_file = join(dset.format(label), key+'.txt')
44
+ gt_file = join(os.path.split(dset.format(label))[0],'gt', key+'.txt')
45
+ if label == 'mal':
46
+ gts.append(file_to_bbox(gt_file))
47
+ else:
48
+ gts.append([])
49
+
50
+ # TODO: Note this
51
+ flag = False
52
+ for mp in mp_dict:
53
+ if dataset_paths[i].find(mp) != -1:
54
+ preds.append(mp_dict[mp](file_to_bbox(pred_file)))
55
+ flag = True
56
+ break
57
+ if not flag:
58
+ preds.append(file_to_bbox(pred_file))
59
+
60
+ # Ensure all gts are same
61
+ gt = gts[0]
62
+ for g in gts[1:]:
63
+ assert g == gt
64
+ gt = g
65
+
66
+ # Flatten Preds
67
+ preds = [np.array(p) for p in preds]
68
+ preds = [np.array([[0.,0.,0.,0.,0.]]) if pred.shape==(0,) else pred for pred in preds]
69
+ preds = [np.vstack((p, np.zeros((100 - len(p), 5)))) for p in preds]
70
+ image_dict[key] = dict()
71
+ image_dict[key]['gt'] = gts[0]
72
+ image_dict[key]['preds'] = preds
73
+ return image_dict
74
+
75
+
76
+ def apply_merge(image_dict, METHOD = 'wbf', weights = None, conf_type = None):
77
+ FACTOR = 5000
78
+ fusion_func = weighted_boxes_fusion if METHOD == 'wbf' else non_maximum_weighted
79
+ for key in image_dict:
80
+ preds = np.array(image_dict[key]['preds'])
81
+ if len(preds) != 0:
82
+ boxes_list = [pred[:,1:]/FACTOR for pred in preds]
83
+ scores_list = [pred[:,0] for pred in preds]
84
+ labels = [[0. for _ in range(len(p))] for p in preds]
85
+ if weights is None:
86
+ weights = [1 for _ in range(len(preds))]
87
+ if METHOD == 'wbf' and conf_type is not None:
88
+ boxes,scores,_ = fusion_func(boxes_list, scores_list, labels, weights = weights,iou_thr = 0.5, conf_type = conf_type)
89
+ else:
90
+ boxes,scores,_ = fusion_func(boxes_list, scores_list, labels, weights = weights,iou_thr = 0.5,)
91
+ preds_t = [[scores[i],FACTOR*boxes[i][0],FACTOR*boxes[i][1],FACTOR*boxes[i][2],FACTOR*boxes[i][3]] for i in range(len(boxes))]
92
+ image_dict[key]['preds'] = preds_t
93
+ return image_dict
94
+
95
+ def manipulate_preds(preds):
96
+ return preds
97
+
98
+
99
+
100
+ def manipulate_preds_4(preds):
101
+ return preds
102
+
103
+ tot = 0
104
+ def manipulate_preds_t1(preds): #return manipulate_preds(preds)
105
+ preds = list(filter(lambda x: x[0]>0.6,preds))
106
+
107
+ return preds
108
+
109
+ def manipulate_preds_t2(preds): return manipulate_preds_t1(preds)
110
+
111
+
112
+ if __name__ == '__main__':
113
+ USE_ACR = False
114
+ dataset_paths = [
115
+ 'MammoDatasets/AIIMS_C1/test/{0}/preds_frcnn_AIIMS_C1',
116
+ 'MammoDatasets/AIIMS_C2/test/{0}/preds_frcnn_AIIMS_C2',
117
+ 'MammoDatasets/AIIMS_C3/test/{0}/preds_frcnn_AIIMS_C3',
118
+ 'MammoDatasets/AIIMS_C4/test/{0}/preds_frcnn_AIIMS_C4',
119
+ 'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_bilateral_BILATERAL',
120
+ 'MammoDatasets/AIIMS_highres_reliable/test/{0}/preds_frcnn_16',
121
+ ]
122
+
123
+
124
+ st = int(sys.argv[1])
125
+ end = len(dataset_paths) - int(sys.argv[2])
126
+ allowed = [i for i in range(st,end)]
127
+ allowed = [0,1,2,3,4,5]
128
+
129
+ OUT_FILE = 'contrast_frcnn.txt'
130
+ if OUT_FILE is not None:
131
+ fol = os.path.split(OUT_FILE)[0]
132
+ if fol != '':
133
+ os.makedirs(fol, exist_ok=True)
134
+
135
+ acr_cat = json.load(open('aiims_categories.json','r'))
136
+ print(allowed)
137
+
138
+ mp_dict = {
139
+ 'preds_frcnn_AIIMS_C3': manipulate_preds,
140
+ 'preds_frcnn_AIIMS_C4': manipulate_preds_4,
141
+ 'AIIMS_T2': manipulate_preds_t2,
142
+ 'AIIMS_T1': manipulate_preds_t1,
143
+ }
144
+
145
+ image_dict = get_image_dict(dataset_paths, allowed = allowed, USE_ACR = USE_ACR, acr_cat = acr_cat, mp_dict = mp_dict)
146
+
147
+ image_dict = apply_merge(image_dict, METHOD = 'nms') # or wbf
148
+
149
+ if OUT_FILE:
150
+ pickle.dump(image_dict, open(OUT_FILE.replace('.txt','.pkl'),'wb'))
151
+ senses, fps = calc_froc_from_dict(image_dict, fps_req = [0.025,0.05,0.1,0.15,0.2,0.3,1.],save_to=OUT_FILE)
152
+ pretty_print_fps(senses, fps)
DenseMammogram/model_utils.py ADDED
@@ -0,0 +1,83 @@
1
+ import os
2
+ import torchvision.transforms as T
3
+ import cv2
4
+ from tqdm import tqdm
5
+ import detection.transforms as transforms
6
+ from dataloaders import get_direction
7
+
8
+ def generate_predictions_bilateral(model,device,testpath_,cor_dict,dset='aiims',preds_folder='preds_new'):
9
+ transform = T.Compose([T.ToPILImage(),T.ToTensor()])
10
+ model.eval()
11
+ for label in ['mal','ben']:
12
+ testpath = os.path.join(testpath_,label)
13
+ # testpath = os.path.join(dataset_path,'Training', 'train',label)
14
+ testimg = os.path.join(testpath, 'images')
15
+
16
+ #preds_folder = 'preds_new'
17
+ os.makedirs(os.path.join(testpath, preds_folder),exist_ok=True)
18
+
19
+ if not os.path.exists(os.path.join(testpath,preds_folder)):
20
+ os.makedirs(os.path.join(testpath+preds_folder),exist_ok = True)
21
+
22
+ for file in tqdm(os.listdir(testimg)):
23
+ img1 = cv2.imread(os.path.join(testimg,file))
24
+ img1 = transform(img1)
25
+ # if False:
26
+ if(os.path.join(testimg,file) in cor_dict and os.path.isfile(cor_dict[os.path.join(testimg,file)])):
27
+ print('Using Bilateral')
28
+ img2 = cv2.imread(cor_dict[os.path.join(testimg,file)])
29
+ img2 = transform(img2)
30
+ if(get_direction(dset,file)==1):
31
+ img1,_ = transforms.RandomHorizontalFlip(1.0)(img1)
32
+
33
+ images = [img1.to(device),img2.to(device)]
34
+ output = model([images])[0]
35
+ img1,output = transforms.RandomHorizontalFlip(1.0)(img1,output)
36
+ else:
37
+ img2,_ = transforms.RandomHorizontalFlip(1.0)(img2)
38
+
39
+ images = [img1.to(device),img2.to(device)]
40
+ output = model([images])[0]
41
+ else:
42
+ print('Using FRCNN')
43
+ output = model.frcnn([img1.to(device)])[0]
44
+ #output = model.frcnn([img1.to(device)])[0]
45
+ boxes = output['boxes']
46
+ scores = output['scores']
47
+ labels = output['labels']
48
+ f = open(os.path.join(testpath,preds_folder,file[:-4]+'.txt'),'w')
49
+ for i in range(len(boxes)):
50
+ box = boxes[i].detach().cpu().numpy()
51
+ #f.write('{} {} {} {} {} {}\n'.format(scores[i].item(),labels[i].item(),box[0],box[1],box[2],box[3]))
52
+ f.write('{} {} {} {} {}\n'.format(scores[i].item(),box[0],box[1],box[2],box[3]))
53
+
54
+
55
+ def generate_predictions(model,device,testpath_,preds_folder='preds_frcnn'):
56
+ transform = T.Compose([T.ToPILImage(),T.ToTensor()])
57
+ model.eval()
58
+ for label in ['mal','ben']:
59
+ testpath = os.path.join(testpath_,label)
60
+ # testpath = os.path.join(dataset_path,'Training', 'train',label)
61
+ testimg = os.path.join(testpath, 'images')
62
+
63
+ #preds_folder = 'preds_new'
64
+ os.makedirs(os.path.join(testpath, preds_folder),exist_ok=True)
65
+
66
+ if not os.path.exists(os.path.join(testpath,preds_folder)):
67
+ os.makedirs(os.path.join(testpath, preds_folder), exist_ok=True)
68
+
69
+ for file in tqdm(os.listdir(testimg)):
70
+ im = cv2.imread(os.path.join(testimg,file))
71
+ if file == 'Mass-Training_P_00444_LEFT_CC.png':
72
+ print('Skipping known problematic image:', file)
73
+ continue
74
+ im = transform(im)
75
+
76
+ output = model([im.to(device)])[0]
77
+ boxes = output['boxes'] #/ FAC
78
+ scores = output['scores']
79
+ labels = output['labels']
80
+ f = open(os.path.join(testpath,preds_folder,file[:-4]+'.txt'),'w')
81
+ for i in range(len(boxes)):
82
+ box = boxes[i].detach().cpu().numpy()
83
+ f.write('{} {} {} {} {}\n'.format(scores[i].item(),box[0],box[1],box[2],box[3]))
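Reviewer note: a minimal usage sketch of generate_predictions, under the dataset layout assumed elsewhere in this repo (a split folder containing mal/ and ben/ subfolders, each with an images/ directory; the MammoDatasets path is an assumption and the data itself is not bundled). It writes one "score x1 y1 x2 y2" line per detection into the chosen preds folder.

```python
# Sketch only; dataset path and output folder name are placeholders.
import torch
from models import get_FRCNN_model
from model_utils import generate_predictions

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = get_FRCNN_model().to(device)
model.load_state_dict(torch.load('pretrained_models/frcnn/frcnn_models/frcnn_model.pth',
                                 map_location=device))
generate_predictions(model, device,
                     'MammoDatasets/AIIMS_highres_reliable/test',
                     preds_folder='preds_frcnn_demo')
```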
DenseMammogram/models.py ADDED
@@ -0,0 +1,201 @@
1
+ from typing import List, OrderedDict, Tuple
2
+ import warnings
3
+ import numpy as np
4
+ import pandas as pd
5
+ import cv2
6
+ import os
7
+ from torch.nn.modules.conv import Conv2d
8
+ from torch.utils.data.dataset import ConcatDataset
9
+ from tqdm import tqdm
10
+ import argparse
11
+ from torch.utils.data import Dataset,DataLoader
12
+ import torch
13
+ import torch.nn as nn
14
+ from torchvision import models
15
+ import detection.transforms as transforms
16
+ import torchvision.transforms as T
17
+ import detection.utils as utils
18
+ import torch.nn.functional as F
19
+ import shutil
20
+ import json
21
+ from detection.engine import train_one_epoch, evaluate
22
+ from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
23
+ import torch.multiprocessing
24
+ import copy
25
+ from torchvision.ops import MultiScaleRoIAlign
26
+ from torchvision.models.detection.roi_heads import RoIHeads
27
+
28
+
29
+
30
+
31
+ # First we will create the FRCNN model
32
+ def get_FRCNN_model(num_classes=1):
33
+ model = models.detection.fasterrcnn_resnet50_fpn(pretrained=True,trainable_backbone_layers=3,min_size=1800,max_size=3600,image_std=(1.0,1.0,1.0),box_score_thresh=0.001)
34
+ # get number of input features for the classifier
35
+ in_features = model.roi_heads.box_predictor.cls_score.in_features
36
+ # replace the pre-trained head with a new one
37
+ model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes+1)
38
+ return model
39
+
40
+ # Some utility heads for Bilateral Model
41
+
42
+ class RoIpool(nn.Module):
43
+
44
+ def __init__(self,pool):
45
+ super().__init__()
46
+ self.box_roi_pool1 = copy.deepcopy(pool)
47
+ self.box_roi_pool2 = copy.deepcopy(pool)
48
+
49
+
50
+ def forward(self,features,proposals,image_shapes):
51
+ x = self.box_roi_pool1(features[0],proposals,image_shapes)
52
+ y = self.box_roi_pool2(features[1],proposals,image_shapes)
53
+ z = torch.cat((x,y),dim=1)
54
+ return z
55
+
56
+ class TwoMLPHead(nn.Module):
57
+ """
58
+ Standard heads for FPN-based models
59
+ Args:
60
+ in_channels (int): number of input channels
61
+ representation_size (int): size of the intermediate representation
62
+ """
63
+
64
+ def __init__(self, in_channels=None, representation_size=None):
65
+ super().__init__()
66
+
67
+ self.fc6 = nn.Linear(in_channels, representation_size)
68
+ self.fc7 = nn.Linear(representation_size, representation_size)
69
+
70
+ def forward(self, x):
71
+ x = x.flatten(start_dim=1)
72
+
73
+ x = F.relu(self.fc6(x))
74
+ x = F.relu(self.fc7(x))
75
+ return x
76
+
77
+ # Next the bilateral model
78
+
79
+ class Bilateral_model(nn.Module):
80
+
81
+ def __init__(self,frcnn_model):
82
+ super().__init__()
83
+ self.frcnn = frcnn_model
84
+ self.transform = copy.deepcopy(frcnn_model.transform)
85
+ self.backbone1 = copy.deepcopy(frcnn_model.backbone)
86
+ self.backbone2 = copy.deepcopy(frcnn_model.backbone)
87
+ self.rpn = copy.deepcopy(frcnn_model.rpn)
88
+ for param in self.rpn.parameters():
89
+ param.requires_grad = False
90
+ for param in self.backbone1.parameters():
91
+ param.requires_grad = False
92
+ for param in self.backbone2.parameters():
93
+ param.requires_grad = False
94
+ box_roi_pool = RoIpool(frcnn_model.roi_heads.box_roi_pool)
95
+ box_head = TwoMLPHead(512*7*7,1024)
96
+ box_predictor = copy.deepcopy(frcnn_model.roi_heads.box_predictor)
97
+ box_score_thresh=0.001
98
+ box_nms_thresh=0.5
99
+ box_detections_per_img=100
100
+ box_fg_iou_thresh=0.5
101
+ box_bg_iou_thresh=0.5
102
+ box_batch_size_per_image=512
103
+ box_positive_fraction=0.25
104
+ bbox_reg_weights=None
105
+ self.roi_heads = RoIHeads(
106
+ # Box
107
+ box_roi_pool,
108
+ box_head,
109
+ box_predictor,
110
+ box_fg_iou_thresh,
111
+ box_bg_iou_thresh,
112
+ box_batch_size_per_image,
113
+ box_positive_fraction,
114
+ bbox_reg_weights,
115
+ box_score_thresh,
116
+ box_nms_thresh,
117
+ box_detections_per_img,
118
+ )
119
+
120
+ @torch.jit.unused
121
+ def eager_outputs(self, losses, detections):
122
+ if self.training:
123
+ return losses
124
+
125
+ return detections
126
+
127
+
128
+ def forward(self, images, targets=None):
129
+ """
130
+ Args:
131
+ images (list[tuple[Tensor, Tensor]]): pairs of (examined view, contralateral view) images to be processed
132
+ targets (list[Dict[str, Tensor]]): ground-truth boxes present in the image (optional)
133
+ Returns:
134
+ result (list[BoxList] or dict[Tensor]): the output from the model.
135
+ During training, it returns a dict[Tensor] which contains the losses.
136
+ During testing, it returns a list[BoxList] containing additional fields
137
+ like `scores`, `labels` and `mask` (for Mask R-CNN models).
138
+ """
139
+ if self.training and targets is None:
140
+ raise ValueError("In training mode, targets should be passed")
141
+ if self.training:
142
+ assert targets is not None
143
+ for target in targets:
144
+ boxes = target["boxes"]
145
+ if isinstance(boxes, torch.Tensor):
146
+ if len(boxes.shape) != 2 or boxes.shape[-1] != 4:
147
+ raise ValueError(f"Expected target boxes to be a tensor of shape [N, 4], got {boxes.shape}.")
148
+ else:
149
+ raise ValueError(f"Expected target boxes to be of type Tensor, got {type(boxes)}.")
150
+
151
+ original_image_sizes: List[Tuple[int, int]] = []
152
+ for img in images:
153
+ val = img[0].shape[-2:]
154
+ assert len(val) == 2
155
+ original_image_sizes.append((val[0], val[1]))
156
+ images1 = [img[0] for img in images]
157
+ images2 = [img[1] for img in images]
158
+ targets2 = copy.deepcopy(targets)
159
+ #print(images1.shape)
160
+ #print(images2.shape)
161
+ images1, targets = self.transform(images1, targets)
162
+ images2, targets2 = self.transform(images2, targets2)
163
+
164
+ # Check for degenerate boxes
165
+ # TODO: Move this to a function
166
+ if targets is not None:
167
+ for target_idx, target in enumerate(targets):
168
+ boxes = target["boxes"]
169
+ degenerate_boxes = boxes[:, 2:] <= boxes[:, :2]
170
+ if degenerate_boxes.any():
171
+ # print the first degenerate box
172
+ bb_idx = torch.where(degenerate_boxes.any(dim=1))[0][0]
173
+ degen_bb: List[float] = boxes[bb_idx].tolist()
174
+ raise ValueError(
175
+ "All bounding boxes should have positive height and width."
176
+ f" Found invalid box {degen_bb} for target at index {target_idx}."
177
+ )
178
+
179
+ features1 = self.backbone1(images1.tensors)
180
+ features2 = self.backbone2(images2.tensors)
181
+ #print(self.backbone1.out_channels)
182
+ if isinstance(features1, torch.Tensor):
183
+ features1 = OrderedDict([("0", features1)])
184
+ if isinstance(features2, torch.Tensor):
185
+ features2 = OrderedDict([("0", features2)])
186
+ proposals, proposal_losses = self.rpn(images1, features1, targets)
187
+ features = {0:features1,1:features2}
188
+ detections, detector_losses = self.roi_heads(features, proposals, images1.image_sizes, targets)
189
+ detections = self.transform.postprocess(detections, images1.image_sizes, original_image_sizes) # type: ignore[operator]
190
+
191
+ losses = {}
192
+ losses.update(detector_losses)
193
+ losses.update(proposal_losses)
194
+
195
+ if torch.jit.is_scripting():
196
+ if not self._has_warned:
197
+ warnings.warn("RCNN always returns a (Losses, Detections) tuple in scripting")
198
+ self._has_warned = True
199
+ return losses, detections
200
+ else:
201
+ return self.eager_outputs(losses, detections)
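Reviewer note: a minimal inference sketch for the two models defined above (random tensors stand in for real mammogram images, and the sizes are dummy values). Bilateral_model expects each element of `images` to be a [examined view, contralateral view] pair and, in eval mode, returns one dict per image with 'boxes', 'scores' and 'labels'.

```python
# Sketch only; real inputs come from cv2.imread + ToTensor as in model_utils.py.
import torch
from models import get_FRCNN_model, Bilateral_model

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
frcnn = get_FRCNN_model().to(device)
model = Bilateral_model(frcnn).to(device).eval()

main_view = torch.rand(3, 900, 600, device=device)   # dummy examined view
other_view = torch.rand(3, 900, 600, device=device)  # dummy contralateral view
with torch.no_grad():
    det = model([[main_view, other_view]])[0]
print(det['boxes'].shape, det['scores'].shape, det['labels'].shape)
```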
DenseMammogram/plot_froc.py ADDED
@@ -0,0 +1,43 @@
1
+ import matplotlib.pyplot as plt
2
+ import numpy as np
3
+
4
+ ####### PARAMETERS TO ADJUST #######
5
+
6
+ # Specify the files generated from merge_nms and plot corresponding graphs
7
+ base_fol = 'normal_test'
8
+ input_files = {
9
+ f'thresh_uni.txt' : 'Thresh + Uni',
10
+ f'thresh_nouni.txt' : 'Thresh + NoUni',
11
+ }
12
+ save_file = 'uni_vs_nouni.png'
13
+ # TITLE = 'Thresh + Contrast + Bilateral vs Contrast + Bilateral FROC Comparison (Normal Test)'
14
+ TITLE = 'Uni vs NoUni FROC Comparison (Normal Test)'
15
+
16
+ SHOW = False
17
+ CLIP_FPI = 1.2
18
+ MIN_CLIP_FPI = 0.0
19
+ ####################################
20
+
21
+ def plot_froc(input_files, save_file, TITLE = 'FRCNN vs BILATERAL FROC', SHOW = False, CLIP_FPI = 1.2):
22
+ for file in input_files:
23
+ lines = open(file).readlines()
24
+ x = np.array([float(line.split()[0]) for line in lines])
25
+ y = np.array([float(line.split()[1]) for line in lines])
26
+ y = y[x<CLIP_FPI]
27
+ x = x[x<CLIP_FPI]
28
+ y = y[MIN_CLIP_FPI<x]
29
+ x = x[MIN_CLIP_FPI<x]
30
+ plt.plot(x, y, label = input_files[file])
31
+ plt.legend()
32
+
33
+ plt.title(TITLE)
34
+ plt.xlabel('Average False Positive Per Image')
35
+ plt.ylabel('Sensitivity')
36
+
37
+ if SHOW:
38
+ plt.show()
39
+ plt.savefig(save_file)
40
+ plt.clf()
41
+
42
+ if __name__ == '__main__':
43
+ plot_froc(input_files, save_file, TITLE, SHOW, CLIP_FPI)
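Reviewer note: each input file is expected to hold one "<average false positives per image> <sensitivity>" pair per line, the format written by calc_froc_from_dict via its save_to argument. A minimal sketch with hypothetical file names and legend labels:

```python
# Hypothetical file names; values map file -> legend entry.
curves = {
    'results/frcnn.txt': 'FRCNN baseline',
    'results/bilateral.txt': 'Bilateral network',
}
plot_froc(curves, save_file='results/froc_comparison.png',
          TITLE='FRCNN vs Bilateral FROC', SHOW=False, CLIP_FPI=1.2)
```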
DenseMammogram/requirements.txt ADDED
@@ -0,0 +1,11 @@
1
+ torch==1.10.2
2
+ tqdm==4.62.3
3
+ torchvision==0.11.3
4
+ scipy==1.7.3
5
+ scikit-learn==1.0.2
6
+ PyYAML==6.0
7
+ Pillow==8.4.0
8
+ pandas==1.4.0
9
+ matplotlib==3.5.1
10
+ numpy
11
+ easydict==1.9
DenseMammogram/train_bilateral.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ import math
3
+ from advanced_logger import LogPriority
4
+ from dataloaders import get_bilateral_dataloaders
5
+ from models import get_FRCNN_model, Bilateral_model
6
+ from detection.engine import evaluate_loss, train_one_epoch_simplified
7
+
8
+ def main(cfg, experimenter):
9
+
10
+ LR = cfg['LR']
11
+ WEIGHT_DECAY = cfg['WEIGHT_DECAY']
12
+ NUM_EPOCHS = cfg['NUM_EPOCHS']
13
+ BATCH_SIZE = cfg['BATCH_SIZE']
14
+
15
+
16
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
17
+ frcnn_model = get_FRCNN_model().to(device)
18
+ frcnn_model.load_state_dict(torch.load(cfg['FRCNN_MODEL_PATH']))
19
+
20
+ model = Bilateral_model(frcnn_model).to(device)
21
+
22
+ train_loader, val_loader = get_bilateral_dataloaders(experimenter.config, batch_size = BATCH_SIZE, data_dir = experimenter.config['DATA_DIR'])
23
+
24
+ if cfg["OPTIM"] == "SGD":
25
+ optimizer = torch.optim.SGD(filter(lambda p: p.requires_grad , model.roi_heads.parameters()),lr=LR,momentum=0.9,weight_decay=WEIGHT_DECAY)
26
+ elif cfg["OPTIM"] == "ADAM":
27
+ optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr = LR, weight_decay = WEIGHT_DECAY)
28
+ elif cfg["OPTIM"] == "ADAGRAD":
29
+ optimizer = torch.optim.Adagrad(filter(lambda p: p.requires_grad, model.roi_heads.parameters()), lr = LR, weight_decay = WEIGHT_DECAY)
30
+ for epoch in range(NUM_EPOCHS):
31
+ experimenter.start_epoch()
32
+ train_one_epoch_simplified(model, optimizer, train_loader, device, epoch, experimenter = experimenter,optimizer_backbone=None)
33
+ loss = evaluate_loss(model, device, val_loader, experimenter = experimenter)
34
+ experimenter.log('Validation Loss: {}'.format(loss), priority = LogPriority.MEDIUM)
35
+
36
+ experimenter.end_epoch(loss, model, device)
37
+ experimenter.save_model(model)
38
+ experimenter.generate_predictions(model, device)
39
+
40
+
41
+ if __name__ == '__main__':
42
+ from experimenter import Experimenter
43
+ import os
44
+ os.environ['CUDA_VISIBLE_DEVICES'] = '4'
45
+ cfg_file = 'configs/default.cfg'
46
+ experimenter = Experimenter(cfg_file)
47
+ main(experimenter.config['BILATERAL'], experimenter)
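Reviewer note: this script reads its hyperparameters from experimenter.config['BILATERAL'], and train_frcnn.py below reads config['FRCNN'] the same way. The concrete format of configs/default.cfg is defined by advanced_config.py (not shown here), so the following is only a dict-shaped sketch of the keys these scripts actually access; all values are hypothetical.

```python
# Hypothetical values; only the key names are taken from the training scripts.
config = {
    'DATA_DIR': 'MammoDatasets',
    'FRCNN': {'LR': 1e-3, 'WEIGHT_DECAY': 5e-4, 'NUM_EPOCHS': 10, 'BATCH_SIZE': 2},
    'BILATERAL': {'LR': 1e-3, 'WEIGHT_DECAY': 5e-4, 'NUM_EPOCHS': 10,
                  'BATCH_SIZE': 2, 'OPTIM': 'SGD',
                  'FRCNN_MODEL_PATH': 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'},
}
```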
DenseMammogram/train_frcnn.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import math
3
+ from advanced_logger import LogPriority
4
+ from dataloaders import get_FRCNN_dataloaders
5
+ from models import get_FRCNN_model
6
+ from detection.engine import evaluate_loss, evaluate_simplified, train_one_epoch_simplified
7
+
8
+ def main(cfg, experimenter):
9
+
10
+ LR = cfg['LR']
11
+ WEIGHT_DECAY = cfg['WEIGHT_DECAY']
12
+ NUM_EPOCHS = cfg['NUM_EPOCHS']
13
+ BATCH_SIZE = cfg['BATCH_SIZE']
14
+
15
+ device = "cuda" if torch.cuda.is_available() else "cpu"
16
+ model = get_FRCNN_model().to(device)
17
+ train_loader, val_loader = get_FRCNN_dataloaders(experimenter.config, batch_size=BATCH_SIZE, data_dir = experimenter.config['DATA_DIR'])
18
+ optimizer = torch.optim.SGD(model.parameters(),lr=LR,momentum=0.9,weight_decay=WEIGHT_DECAY)
19
+
20
+ for epoch in range(NUM_EPOCHS):
21
+ experimenter.start_epoch()
22
+ train_one_epoch_simplified(model, optimizer, train_loader, device, epoch, experimenter = experimenter)
23
+ evaluate_simplified(model, val_loader, device=device, experimenter = experimenter)
24
+ loss = evaluate_loss(model, device, val_loader, experimenter = experimenter)
25
+ experimenter.log('Validation Loss: {}'.format(loss), priority = LogPriority.MEDIUM)
26
+ experimenter.end_epoch(loss, model = model, device = device)
27
+ experimenter.save_model(model)
28
+ experimenter.generate_predictions(model, device)
29
+
30
+ if __name__ == '__main__':
31
+ from experimenter import Experimenter
32
+ cfg_file = 'configs/AIIMS_C1.cfg'
33
+ experimenter = Experimenter(cfg_file)
34
+ main(experimenter.config['FRCNN'], experimenter)
DenseMammogram/utils.py ADDED
@@ -0,0 +1,41 @@
1
+ import shutil
2
+ import os
3
+ from os.path import join
4
+
5
+
6
+ class AverageMeter:
7
+ """Computes and stores the average and current value"""
8
+ def __init__(self):
9
+ self.reset()
10
+
11
+ def reset(self):
12
+ self.val = 0
13
+ self.avg = 0
14
+ self.sum = 0
15
+ self.count = 0
16
+
17
+ def update(self, val, n=1):
18
+ self.val = val
19
+ self.sum += val * n
20
+ self.count += n
21
+ self.avg = self.sum / self.count
22
+
23
+ def create_backup(folders = None, files = None, backup_dir = 'experiments'):
24
+ if folders is None:
25
+ folders = ['.', 'corr_lists','detection']
26
+ if files is None:
27
+ files = ['.py', '.txt', '.json','.cfg']
28
+
29
+ for folder in folders:
30
+ if not os.path.isdir(folder):
31
+ continue
32
+ for file in os.listdir(folder):
33
+ if file.endswith(tuple(files)):
34
+ if folder != '.':
35
+ src = join(folder, file)
36
+ dest = join(backup_dir, folder, file)
37
+ else:
38
+ src = file
39
+ dest = join(backup_dir, file)
40
+ os.makedirs(os.path.split(dest)[0], exist_ok=True)
41
+ shutil.copy(src, dest)
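Reviewer note: a brief sketch of how these two helpers are typically used; the numbers and the experiment folder are illustrative only. AverageMeter keeps a weighted running average (e.g. of a per-batch loss), and create_backup snapshots source files into an experiment folder.

```python
# Illustrative values only.
meter = AverageMeter()
for batch_loss, batch_size in [(0.91, 4), (0.74, 4), (0.68, 2)]:
    meter.update(batch_loss, n=batch_size)
print(round(meter.avg, 3))  # weighted average over the 10 samples seen so far

create_backup(backup_dir='experiments/run_001')  # copies *.py/*.txt/*.json/*.cfg files
```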
app.py ADDED
@@ -0,0 +1,117 @@
1
+ import gradio as gr
2
+ from model import predict
3
+ import cv2
4
+
5
+
6
+ with gr.Blocks() as demo:
7
+ with gr.Column():
8
+ title = "<h1 style='margin-bottom: -10px; text-align: center'>Deep Learning for Detection of iso-dense, obscure masses in mammographically dense breasts</h1>"
9
+ # gr.HTML(title)
10
+ gr.Markdown(
11
+ "<h1 style='text-align: center; margin-bottom: 1rem'>"
12
+ + title
13
+ + "</h1>"
14
+ )
15
+
16
+ description = "<p style='font-size: 14px; margin: 5px; font-weight: w300; text-align: center'> <a href='' style='text-decoration:none' target='_blank'>Krithika Rangarajan<sup>*</sup>, </a> <a href='https://github.com/Pranjal2041' style='text-decoration:none' target='_blank'>Pranjal Aggarwal<sup>*</sup>, </a> <a href='' style='text-decoration:none' target='_blank'>Dhruv Kumar Gupta, </a> <a href='' style='text-decoration:none' target='_blank'>Rohan Dhanakshirur, </a> <a href='' style='text-decoration:none' target='_blank'>Akhil Baby, </a> <a href='' style='text-decoration:none' target='_blank'>Chandan Pal, </a> <a href='' style='text-decoration:none' target='_blank'>Arun Kumar Gupta, </a> <a href='' style='text-decoration:none' target='_blank'>Smriti Hari, </a> <a href='' style='text-decoration:none' target='_blank'>Subhashis Banerjee, </a> <a href='' style='text-decoration:none' target='_blank'>Chetan Arora, </a> </p>" \
17
+ + "<p style='font-size: 16px; margin: 5px; font-weight: w600; text-align: center'> <a href='https://link.springer.com/article/10.1007/s00330-023-09717-7' target='_blank'>Publication</a> | <a href='https://github.com/Pranjal2041/DenseMammogram' target='_blank'>Website</a> | <a href='https://github.com/Pranjal2041/DenseMammogram' target='_blank'>Github Repo</a></p>" \
18
+ + "<p style='text-align: center; margin: 5px; font-size: 14px; font-weight: w300;'> \
19
+ Deep learning suffers from some of the same limitations as human radiologists, such as poor sensitivity in detecting iso-dense, obscure masses or cancers in dense breasts. Traditional radiology teaching can be incorporated into the deep learning approach to tackle these problems. Our method proposes a collaborative network design that incorporates core radiology principles, achieving state-of-the-art results. You can use this demo to run inference by providing bilateral mammogram images. To get started, try one of the preset examples. \
20
+ </p>" \
21
+ + "<p style='text-align: center; font-size: 14px; margin: 5px; font-weight: w300;'> [Note: Inference on CPU may take upto 2 minutes. On a GPU, inference time is approximately 1s.]</p>"
22
+ # gr.HTML(description)
23
+ gr.Markdown(description)
24
+
25
+ # head_html = gr.HTML('''
26
+ # <h1>
27
+ # Deep Learning for Detection of iso-dense, obscure masses in mammographically dense breasts
28
+ # </h1>
29
+ # <p style='text-align: center;'>
30
+ # Give bilateral mammograms(both left and right sides), and let our model find the cancers!
31
+ # </p>
32
+
33
+ # <p style='text-align: center;'>
34
+ # This is an official demo for our paper:
35
+ # `Deep Learning for Detection of iso-dense, obscure masses in mammographically dense breasts`.
36
+ # Check out the paper and code for details!
37
+ # </p>
38
+ # ''')
39
+
40
+ # gr.Markdown(
41
+ # """
42
+ # [![report](https://img.shields.io/badge/arxiv-report-red)](https://arxiv.org/abs/) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/)
43
+ # """)
44
+
45
+ def generate_preds(img1, img2):
46
+ print(img1, img2)
47
+ # run the bilateral model on both left->right and right->left orderings
48
+ img_out1 = predict(img1, img2)
49
+ if img_out1.shape[1] < img_out1.shape[2]:
50
+ ratio = img_out1.shape[2] / 800
51
+ else:
52
+ ratio = img_out1.shape[1] / 800
53
+ img_out1 = cv2.resize(img_out1, (0,0), fx=1 / ratio, fy=1 / ratio)
54
+ img_out2 = predict(img2, img1, baseIsLeft = False)
55
+ if img_out2.shape[1] < img_out2.shape[2]:
56
+ ratio = img_out2.shape[2] / 800
57
+ else:
58
+ ratio = img_out2.shape[1] / 800
59
+ img_out2 = cv2.resize(img_out2, (0,0), fx= 1 / ratio, fy= 1 / ratio)
60
+
61
+ cv2.imwrite('img_out1.jpg', img_out1)
62
+ cv2.imwrite('img_out2.jpg', img_out2)
63
+
64
+
65
+ return 'img_out1.jpg', 'img_out2.jpg'
66
+
67
+ with gr.Column():
68
+ with gr.Row(variant = 'panel'):
69
+
70
+ with gr.Column(variant = 'panel'):
71
+ img1 = gr.Image(type="filepath", label="Left Image" )
72
+ img2 = gr.Image(type="filepath", label="Right Image")
73
+ # with gr.Row():
74
+ # sub_btn = gr.Button("Predict!", variant="primary")
75
+
76
+ with gr.Column(variant = 'panel'):
77
+ # img_out1 = gr.inputs.Image(type="file", label="Output Left Image")
78
+ # img_out2 = gr.inputs.Image(type="file", label="Output for Right Image")
79
+ img_out1 = gr.Image(type="filepath", label="Output for Left Image", shape = None)
80
+ img_out1.style(height=250 * 2)
81
+
82
+ with gr.Column(variant = 'panel'):
83
+ img_out2 = gr.Image(type="filepath", label="Output for Right Image", shape = None)
84
+ img_out2.style(height=250 * 2)
85
+
86
+ with gr.Row():
87
+ sub_btn = gr.Button("Predict!", variant="primary")
88
+
89
+ gr.Examples([[f'sample_images/img{idx}_l.jpg', f'sample_images/img{idx}_r.jpg'] for idx in range(1,6)], inputs = [img1, img2])
90
+
91
+ sub_btn.click(fn = lambda x,y: generate_preds(x,y), inputs = [img1, img2], outputs = [img_out1, img_out2])
92
+
93
+ # sub_btn.click(fn = lambda x: gr.update(visible = True), inputs = [sub_btn], outputs = [img_out1, img_out2])
94
+
95
+ # gr.Examples(
96
+
97
+ # )
98
+
99
+
100
+ # interface.render()
101
+ # Object Detection Interface
102
+
103
+ # def generate_predictions(img1, img2):
104
+ # return img1
105
+
106
+ # interface = gr.Interface(
107
+ # fn=generate_predictions,
108
+ # inputs=[gr.inputs.Image(type="pil", label="Left Image"), gr.inputs.Image(type="pil", label="Right Image")],
109
+ # outputs=[gr.outputs.Image(type="pil", label="Output Image")],
110
+ # title="Object Detection",
111
+ # description="This model is trained on DenseMammogram dataset. It can detect objects in images. Try it out!",
112
+ # allow_flagging = False
113
+ # ).launch(share = True, show_api=False)
114
+
115
+
116
+ if __name__ == '__main__':
117
+ demo.launch(share = True, show_api=False)
img_out1.jpg ADDED

Git LFS Details

  • SHA256: 278c18719edfb89968f1c3b8018d89565a985959a4ed23753db4ae8347381826
  • Pointer size: 131 Bytes
  • Size of remote file: 376 kB
img_out2.jpg ADDED

Git LFS Details

  • SHA256: 4f2da3a653a186a4056c956372a4f3caa5936bf2c85c05ecb8b29e84eb637e4f
  • Pointer size: 131 Bytes
  • Size of remote file: 300 kB
model.py ADDED
@@ -0,0 +1,57 @@
1
+ import sys
2
+ sys.path.append('DenseMammogram')
3
+
4
+ import torch
5
+
6
+ from models import get_FRCNN_model, Bilateral_model
7
+
8
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
9
+ frcnn_model = get_FRCNN_model().to(device)
10
+ bilat_model = Bilateral_model(frcnn_model).to(device)
11
+
12
+ FRCNN_PATH = 'pretrained_models/frcnn/frcnn_models/frcnn_model.pth'
13
+ BILAR_PATH = 'pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth'
14
+
15
+ frcnn_model.load_state_dict(torch.load(FRCNN_PATH, map_location=device))
16
+ bilat_model.load_state_dict(torch.load(BILAR_PATH, map_location=device))
17
+
18
+ import os
19
+ import torchvision.transforms as T
20
+ import cv2
21
+ from tqdm import tqdm
22
+ import detection.transforms as transforms
23
+ from dataloaders import get_direction
24
+
25
+ def predict(left_file, right_file, threshold = 0.80, baseIsLeft = True):
26
+ model = bilat_model
27
+ with torch.no_grad():
28
+ transform = T.Compose([T.ToPILImage(),T.ToTensor()])
29
+ model.eval()
30
+ # First is left, then right
31
+ img1 = cv2.imread(left_file)
32
+ img1 = transform(img1)
33
+ img2 = cv2.imread(right_file)
34
+ img2 = transform(img2)
35
+
36
+ if baseIsLeft:
37
+ img1,_ = transforms.RandomHorizontalFlip(1.0)(img1)
38
+ else:
39
+ img2,_ = transforms.RandomHorizontalFlip(1.0)(img2)
40
+
41
+
42
+ images = [img1.to(device),img2.to(device)]
43
+ output = model([images])[0]
44
+ if baseIsLeft:
45
+ img1,output = transforms.RandomHorizontalFlip(1.0)(img1,output)
46
+
47
+ image = cv2.imread(left_file)
48
+ for b,s,l in zip(output['boxes'], output['scores'], output['labels']):
49
+ # Keep only malignant detections above the confidence threshold
50
+ if l == 1 and s > threshold:
51
+ # Draw the bounding boxes
52
+ b = b.detach().cpu().numpy().astype(int)
53
+ # return image, b
54
+ cv2.rectangle(image, (b[0], b[1]), (b[2], b[3]), (0, 255, 0), 2)
55
+ # Print the % probability just above the box
56
+ cv2.putText(image, 'Cancer: '+str(round(round(s.item(), 2) * 100, 1)) + '%', (b[0], b[1] - 40), cv2.FONT_HERSHEY_SIMPLEX, 3.6, (36,255,12), 6)
57
+ return image
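Reviewer note: a minimal usage sketch for predict, mirroring generate_preds in app.py; the file names point at the bundled sample images used by the demo's examples, and the output names are arbitrary.

```python
# Sketch only; annotate the left view with the right view as contralateral
# context, then the reverse.
import cv2

left, right = 'sample_images/img1_l.jpg', 'sample_images/img1_r.jpg'
annotated_left = predict(left, right)                     # boxes drawn on the left view
annotated_right = predict(right, left, baseIsLeft=False)  # boxes drawn on the right view
cv2.imwrite('annotated_left.jpg', annotated_left)
cv2.imwrite('annotated_right.jpg', annotated_right)
```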
pretrained_models/AIIMS_C1/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e4253bd5cda58b57e1ed38cbaadd7fa7698cbc47bcd4c795f27cf0a63a7da669
3
+ size 165725683
pretrained_models/AIIMS_C2/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:07ca463a86317a4db3f3ed24358ddf292701ea2a0daf67b966ac325e7d0bebae
3
+ size 165725683
pretrained_models/AIIMS_C3/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:51ec560b1b56b9199480dee4eaaa10f45b4b96feab9397dd90f4eb05f21fd6d5
3
+ size 165725683
pretrained_models/AIIMS_C4/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d18b23c2a1e06a11a27ebd77e87dbb6b27d54e88d92fc55d58c64957b8cdfcfb
3
+ size 165725683
pretrained_models/AIIMS_T1/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8d8a1d133d3629e9c717070a66e1f2f2f846daca6765097622c2fe9f95c5a513
3
+ size 165725683
pretrained_models/AIIMS_T2/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5db00c682eec86bb2b4e764b64feffa26774643dae780bf3cf81313f5ca6f8de
3
+ size 165725683
pretrained_models/BILATERAL/bilateral_models/bilateral_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dce00a005fd102839f17c490b4a58191e92e99965b1ac7e323b71b0e75043d37
3
+ size 490558451
pretrained_models/frcnn/frcnn_models/frcnn_model.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e92090fd249484577db1c9e2560c82abddffd4c62203195bf8c35a32beeed4ad
3
+ size 165725683
requirements.txt ADDED
@@ -0,0 +1,12 @@
1
+ gradio
2
+ torch==1.10.2
3
+ tqdm==4.62.3
4
+ torchvision==0.11.3
5
+ scipy==1.7.3
6
+ scikit-learn==1.0.2
7
+ PyYAML==6.0
8
+ Pillow==8.4.0
9
+ pandas==1.4.0
10
+ matplotlib==3.5.1
11
+ numpy
12
+ easydict==1.9