monai
medical
katielink commited on
Commit
f6cc1e0
1 Parent(s): 2208e51

use detection inferer

Browse files
configs/inference.json CHANGED
@@ -13,7 +13,7 @@
13
  "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
14
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
15
  "amp": true,
16
- "val_patch_size": [
17
  512,
18
  512,
19
  192
@@ -67,7 +67,7 @@
67
  "detector_ops": [
68
  "$@detector.set_target_keys(box_key='box', label_key='label')",
69
  "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
70
- "$@detector.set_sliding_window_inferer(roi_size=@val_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
71
  ],
72
  "preprocessing": {
73
  "_target_": "Compose",
@@ -135,14 +135,8 @@
135
  "collate_fn": "$monai.data.utils.no_collation"
136
  },
137
  "inferer": {
138
- "_target_": "SlidingWindowInferer",
139
- "roi_size": [
140
- 240,
141
- 240,
142
- 160
143
- ],
144
- "sw_batch_size": 1,
145
- "overlap": 0.5
146
  },
147
  "postprocessing": {
148
  "_target_": "Compose",
@@ -203,7 +197,8 @@
203
  "_requires_": "@detector_ops",
204
  "device": "@device",
205
  "val_data_loader": "@dataloader",
206
- "detector": "@detector",
 
207
  "val_handlers": "@handlers",
208
  "amp": "@amp"
209
  },
 
13
  "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
14
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
15
  "amp": true,
16
+ "infer_patch_size": [
17
  512,
18
  512,
19
  192
 
67
  "detector_ops": [
68
  "$@detector.set_target_keys(box_key='box', label_key='label')",
69
  "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
70
+ "$@detector.set_sliding_window_inferer(roi_size=@infer_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
71
  ],
72
  "preprocessing": {
73
  "_target_": "Compose",
 
135
  "collate_fn": "$monai.data.utils.no_collation"
136
  },
137
  "inferer": {
138
+ "_target_": "scripts.detection_inferer.RetinaNetInferer",
139
+ "detector": "@detector"
 
 
 
 
 
 
140
  },
141
  "postprocessing": {
142
  "_target_": "Compose",
 
197
  "_requires_": "@detector_ops",
198
  "device": "@device",
199
  "val_data_loader": "@dataloader",
200
+ "network": "@network",
201
+ "inferer": "@inferer",
202
  "val_handlers": "@handlers",
203
  "amp": "@amp"
204
  },
configs/metadata.json CHANGED
@@ -1,7 +1,8 @@
1
  {
2
  "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3
- "version": "0.4.5",
4
  "changelog": {
 
5
  "0.4.5": "fixed some small changes with formatting in readme",
6
  "0.4.4": "add data resource to readme",
7
  "0.4.3": "update val patch size to avoid warning in monai 1.0.1",
 
1
  {
2
  "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
3
+ "version": "0.5.0",
4
  "changelog": {
5
+ "0.5.0": "use detection inferer",
6
  "0.4.5": "fixed some small changes with formatting in readme",
7
  "0.4.4": "add data resource to readme",
8
  "0.4.3": "update val patch size to avoid warning in monai 1.0.1",
configs/train.json CHANGED
@@ -399,6 +399,10 @@
399
  "num_workers": 2,
400
  "collate_fn": "$monai.data.utils.no_collation"
401
  },
 
 
 
 
402
  "handlers": [
403
  {
404
  "_target_": "StatsHandler",
@@ -435,7 +439,8 @@
435
  "_requires_": "@detector_ops",
436
  "device": "@device",
437
  "val_data_loader": "@validate#dataloader",
438
- "detector": "@detector",
 
439
  "key_val_metric": "@validate#key_metric",
440
  "val_handlers": "@validate#handlers",
441
  "amp": "@amp"
 
399
  "num_workers": 2,
400
  "collate_fn": "$monai.data.utils.no_collation"
401
  },
402
+ "inferer": {
403
+ "_target_": "scripts.detection_inferer.RetinaNetInferer",
404
+ "detector": "@detector"
405
+ },
406
  "handlers": [
407
  {
408
  "_target_": "StatsHandler",
 
439
  "_requires_": "@detector_ops",
440
  "device": "@device",
441
  "val_data_loader": "@validate#dataloader",
442
+ "network": "@network",
443
+ "inferer": "@validate#inferer",
444
  "key_val_metric": "@validate#key_metric",
445
  "val_handlers": "@validate#handlers",
446
  "amp": "@amp"
scripts/detection_inferer.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) MONAI Consortium
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ from typing import Any, List, Union
13
+
14
+ import numpy as np
15
+ import torch
16
+ from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
17
+ from monai.inferers.inferer import Inferer
18
+ from torch import Tensor
19
+
20
+
21
+ class RetinaNetInferer(Inferer):
22
+ """
23
+ RetinaNet Inferer takes RetinaNet as input
24
+
25
+ Args:
26
+ detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
27
+ map into boxes and classification scores.
28
+ args: other optional args to be passed to detector.
29
+ kwargs: other optional keyword args to be passed to detector.
30
+ """
31
+
32
+ def __init__(self, detector: RetinaNetDetector, *args, **kwargs) -> None:
33
+ Inferer.__init__(self)
34
+ self.detector = detector
35
+ self.sliding_window_size = None
36
+ if self.detector.inferer is not None:
37
+ if hasattr(self.detector.inferer, "roi_size"):
38
+ self.sliding_window_size = np.prod(self.detector.inferer.roi_size)
39
+
40
+ def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
41
+ """Unified callable function API of Inferers.
42
+ Args:
43
+ inputs: model input data for inference.
44
+ network: target detection network to execute inference.
45
+ supports callable that fullfilles requirements of network in
46
+ monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``
47
+ args: optional args to be passed to ``network``.
48
+ kwargs: optional keyword args to be passed to ``network``.
49
+ """
50
+ self.detector.network = network
51
+ self.detector.training = self.detector.network.training
52
+
53
+ # if image smaller than sliding window roi size, no need to use sliding window inferer
54
+ # use sliding window inferer only when image is large
55
+ use_inferer = self.sliding_window_size is not None and not all(
56
+ [data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs]
57
+ )
58
+
59
+ return self.detector(inputs, use_inferer=use_inferer, *args, **kwargs)
scripts/evaluator.py CHANGED
@@ -13,19 +13,18 @@ from __future__ import annotations
13
 
14
  from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
15
 
16
- import numpy as np
17
  import torch
 
18
  from monai.config import IgniteInfo
19
- from monai.engines.evaluator import Evaluator
20
  from monai.engines.utils import IterationEvents, default_metric_cmp_fn
21
- from monai.inferers import Inferer
22
- from monai.networks.utils import eval_mode, train_mode
23
  from monai.transforms import Transform
24
  from monai.utils import ForwardMode, min_version, optional_import
25
  from monai.utils.enums import CommonKeys as Keys
26
- from monai.utils.module import look_up_option
27
  from torch.utils.data import DataLoader
28
 
 
 
29
  if TYPE_CHECKING:
30
  from ignite.engine import Engine, EventEnum
31
  from ignite.metrics import Metric
@@ -67,17 +66,17 @@ def detection_prepare_val_batch(
67
  return inputs, None
68
 
69
 
70
- class DetectionEvaluator(Evaluator):
71
  """
72
- Supervised detection evaluation method with image and label, inherits from ``Evaluator`` and ``Workflow``.
73
  Args:
74
  device: an object representing the device on which to run.
75
- val_data_loader: Ignite engine use data_loader to run, must be Iterable or torch.DataLoader.
76
- detector: detector to train in the trainer, should be regular PyTorch `torch.nn.Module`.
77
  epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
78
  non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
79
  with respect to the host. For other cases, this argument has no effect.
80
- prepare_batch: function to parse expected data (usually `image`,`box`, `label` and other detector args)
81
  from `engine.state.batch` for every iteration, for more details please refer to:
82
  https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
83
  iteration_update: the callable function for every iteration, expect to accept `engine`
@@ -117,12 +116,12 @@ class DetectionEvaluator(Evaluator):
117
  self,
118
  device: torch.device,
119
  val_data_loader: Iterable | DataLoader,
120
- detector: torch.nn.Module,
121
  epoch_length: int | None = None,
122
  non_blocking: bool = False,
123
  prepare_batch: Callable = detection_prepare_val_batch,
124
  iteration_update: Callable[[Engine, Any], Any] | None = None,
125
- inferer: Inferer | None = None,
126
  postprocessing: Transform | None = None,
127
  key_val_metric: dict[str, Metric] | None = None,
128
  additional_metrics: dict[str, Metric] | None = None,
@@ -139,10 +138,12 @@ class DetectionEvaluator(Evaluator):
139
  super().__init__(
140
  device=device,
141
  val_data_loader=val_data_loader,
 
142
  epoch_length=epoch_length,
143
  non_blocking=non_blocking,
144
  prepare_batch=prepare_batch,
145
  iteration_update=iteration_update,
 
146
  postprocessing=postprocessing,
147
  key_val_metric=key_val_metric,
148
  additional_metrics=additional_metrics,
@@ -157,16 +158,6 @@ class DetectionEvaluator(Evaluator):
157
  amp_kwargs=amp_kwargs,
158
  )
159
 
160
- self.detector = detector
161
-
162
- mode = look_up_option(mode, ForwardMode)
163
- if mode == ForwardMode.EVAL:
164
- self.mode = eval_mode
165
- elif mode == ForwardMode.TRAIN:
166
- self.mode = train_mode
167
- else:
168
- raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
169
-
170
  def _register_decollate(self):
171
  """
172
  Register the decollate operation for batch data, will execute after model forward and loss forward.
@@ -181,48 +172,3 @@ class DetectionEvaluator(Evaluator):
181
  if engine.state.output[k] is not None:
182
  output_list[i][k] = engine.state.output[k][i]
183
  engine.state.output = output_list
184
-
185
- def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
186
- """
187
- callback function for the Supervised Evaluation processing logic of 1 iteration in Ignite Engine.
188
- Return below items in a dictionary:
189
- - IMAGE: image Tensor data for model input, already moved to device.
190
- - LABEL: label Tensor data corresponding to the image, already moved to device.
191
- - PRED: prediction result of model.
192
- Args:
193
- engine: `SupervisedEvaluator` to execute operation for an iteration.
194
- batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
195
- Raises:
196
- ValueError: When ``batchdata`` is None.
197
- """
198
-
199
- if batchdata is None:
200
- raise ValueError("Must provide batch data for current iteration.")
201
-
202
- batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
203
- if len(batch) == 2:
204
- inputs, targets = batch
205
- args: tuple = ()
206
- kwargs: dict = {}
207
- else:
208
- inputs, targets, args, kwargs = batch
209
- # put iteration outputs into engine.state
210
- engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
211
-
212
- # execute forward computation
213
- sliding_window_size = np.prod(engine.detector.inferer.roi_size)
214
-
215
- with engine.mode(engine.detector):
216
-
217
- use_inferer = not all([val_data_i[0, ...].numel() < sliding_window_size for val_data_i in inputs])
218
-
219
- if engine.amp:
220
- with torch.cuda.amp.autocast(**engine.amp_kwargs):
221
- engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
222
- else:
223
- engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
224
-
225
- engine.fire_event(IterationEvents.FORWARD_COMPLETED)
226
- engine.fire_event(IterationEvents.MODEL_COMPLETED)
227
-
228
- return engine.state.output
 
13
 
14
  from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union
15
 
 
16
  import torch
17
+ from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
18
  from monai.config import IgniteInfo
19
+ from monai.engines.evaluator import SupervisedEvaluator
20
  from monai.engines.utils import IterationEvents, default_metric_cmp_fn
 
 
21
  from monai.transforms import Transform
22
  from monai.utils import ForwardMode, min_version, optional_import
23
  from monai.utils.enums import CommonKeys as Keys
 
24
  from torch.utils.data import DataLoader
25
 
26
+ from .detection_inferer import RetinaNetInferer
27
+
28
  if TYPE_CHECKING:
29
  from ignite.engine import Engine, EventEnum
30
  from ignite.metrics import Metric
 
66
  return inputs, None
67
 
68
 
69
+ class DetectionEvaluator(SupervisedEvaluator):
70
  """
71
+ Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``.
72
  Args:
73
  device: an object representing the device on which to run.
74
+ val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
75
+ network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
76
  epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
77
  non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
78
  with respect to the host. For other cases, this argument has no effect.
79
+ prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
80
  from `engine.state.batch` for every iteration, for more details please refer to:
81
  https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
82
  iteration_update: the callable function for every iteration, expect to accept `engine`
 
116
  self,
117
  device: torch.device,
118
  val_data_loader: Iterable | DataLoader,
119
+ network: RetinaNetDetector,
120
  epoch_length: int | None = None,
121
  non_blocking: bool = False,
122
  prepare_batch: Callable = detection_prepare_val_batch,
123
  iteration_update: Callable[[Engine, Any], Any] | None = None,
124
+ inferer: RetinaNetInferer | None = None,
125
  postprocessing: Transform | None = None,
126
  key_val_metric: dict[str, Metric] | None = None,
127
  additional_metrics: dict[str, Metric] | None = None,
 
138
  super().__init__(
139
  device=device,
140
  val_data_loader=val_data_loader,
141
+ network=network,
142
  epoch_length=epoch_length,
143
  non_blocking=non_blocking,
144
  prepare_batch=prepare_batch,
145
  iteration_update=iteration_update,
146
+ inferer=inferer,
147
  postprocessing=postprocessing,
148
  key_val_metric=key_val_metric,
149
  additional_metrics=additional_metrics,
 
158
  amp_kwargs=amp_kwargs,
159
  )
160
 
 
 
 
 
 
 
 
 
 
 
161
  def _register_decollate(self):
162
  """
163
  Register the decollate operation for batch data, will execute after model forward and loss forward.
 
172
  if engine.state.output[k] is not None:
173
  output_list[i][k] = engine.state.output[k][i]
174
  engine.state.output = output_list
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
scripts/warmup_scheduler.py CHANGED
@@ -39,6 +39,7 @@ class GradualWarmupScheduler(_LRScheduler):
39
  super(GradualWarmupScheduler, self).__init__(optimizer)
40
 
41
  def get_lr(self):
 
42
  if self.last_epoch > self.total_epoch:
43
  if self.after_scheduler:
44
  if not self.finished:
 
39
  super(GradualWarmupScheduler, self).__init__(optimizer)
40
 
41
  def get_lr(self):
42
+ self.last_epoch = max(1, self.last_epoch) # to avoid epoch=0 thus lr=0
43
  if self.last_epoch > self.total_epoch:
44
  if self.after_scheduler:
45
  if not self.finished: