use detection inferer
- configs/inference.json +6 -11
- configs/metadata.json +2 -1
- configs/train.json +6 -1
- scripts/detection_inferer.py +59 -0
- scripts/evaluator.py +13 -67
- scripts/warmup_scheduler.py +1 -0
configs/inference.json
CHANGED

@@ -13,7 +13,7 @@
     "test_datalist": "$monai.data.load_decathlon_datalist(@data_list_file_path, is_segmentation=True, data_list_key='validation', base_dir=@data_file_base_dir)",
     "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
     "amp": true,
-    "val_patch_size": [
+    "infer_patch_size": [
        512,
        512,
        192
@@ -67,7 +67,7 @@
     "detector_ops": [
         "$@detector.set_target_keys(box_key='box', label_key='label')",
         "$@detector.set_box_selector_parameters(score_thresh=0.02,topk_candidates_per_level=1000,nms_thresh=0.22,detections_per_img=300)",
-        "$@detector.set_sliding_window_inferer(roi_size=@val_patch_size, …)"
+        "$@detector.set_sliding_window_inferer(roi_size=@infer_patch_size,overlap=0.25,sw_batch_size=1,mode='constant',device='cpu')"
     ],
     "preprocessing": {
         "_target_": "Compose",
@@ -135,14 +135,8 @@
         "collate_fn": "$monai.data.utils.no_collation"
     },
     "inferer": {
-        "_target_": "SlidingWindowInferer",
-        "roi_size": [
-            240,
-            240,
-            160
-        ],
-        "sw_batch_size": 1,
-        "overlap": 0.5
+        "_target_": "scripts.detection_inferer.RetinaNetInferer",
+        "detector": "@detector"
     },
     "postprocessing": {
         "_target_": "Compose",
@@ -203,7 +197,8 @@
     "_requires_": "@detector_ops",
     "device": "@device",
     "val_data_loader": "@dataloader",
-    "detector": "@detector",
+    "network": "@network",
+    "inferer": "@inferer",
     "val_handlers": "@handlers",
     "amp": "@amp"
 },
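Note: with this change the bundle's "inferer" is no longer a generic SlidingWindowInferer but a component defined in the bundle's own scripts. A minimal sketch of how MONAI's ConfigParser would resolve the new entry follows; the parser calls are standard MONAI bundle API, while the working directory and the buildability of the referenced components (@detector, @network, ...) are assumptions.

from monai.bundle import ConfigParser

# Sketch: parse the updated config; assumes the bundle root is the working
# directory and that all referenced components resolve.
parser = ConfigParser()
parser.read_config("configs/inference.json")
inferer = parser.get_parsed_content("inferer")  # instantiates scripts.detection_inferer.RetinaNetInferer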
configs/metadata.json
CHANGED

@@ -1,7 +1,8 @@
 {
     "schema": "https://github.com/Project-MONAI/MONAI-extra-test-data/releases/download/0.8.1/meta_schema_20220324.json",
-    "version": "0.4.5",
+    "version": "0.5.0",
     "changelog": {
+        "0.5.0": "use detection inferer",
         "0.4.5": "fixed some small changes with formatting in readme",
         "0.4.4": "add data resource to readme",
         "0.4.3": "update val patch size to avoid warning in monai 1.0.1",
configs/train.json
CHANGED

@@ -399,6 +399,10 @@
         "num_workers": 2,
         "collate_fn": "$monai.data.utils.no_collation"
     },
+    "inferer": {
+        "_target_": "scripts.detection_inferer.RetinaNetInferer",
+        "detector": "@detector"
+    },
     "handlers": [
         {
             "_target_": "StatsHandler",
@@ -435,7 +439,8 @@
     "_requires_": "@detector_ops",
     "device": "@device",
     "val_data_loader": "@validate#dataloader",
-    "detector": "@detector",
+    "network": "@network",
+    "inferer": "@validate#inferer",
     "key_val_metric": "@validate#key_metric",
     "val_handlers": "@validate#handlers",
     "amp": "@amp"
scripts/detection_inferer.py
ADDED

@@ -0,0 +1,59 @@
+# Copyright (c) MONAI Consortium
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#     http://www.apache.org/licenses/LICENSE-2.0
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, List, Union
+
+import numpy as np
+import torch
+from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
+from monai.inferers.inferer import Inferer
+from torch import Tensor
+
+
+class RetinaNetInferer(Inferer):
+    """
+    RetinaNet inferer that takes a RetinaNetDetector as input.
+
+    Args:
+        detector: the RetinaNetDetector that converts network output BxCxMxN or BxCxMxNxP
+            maps into boxes and classification scores.
+        args: other optional args to be passed to detector.
+        kwargs: other optional keyword args to be passed to detector.
+    """
+
+    def __init__(self, detector: RetinaNetDetector, *args, **kwargs) -> None:
+        Inferer.__init__(self)
+        self.detector = detector
+        self.sliding_window_size = None
+        if self.detector.inferer is not None:
+            if hasattr(self.detector.inferer, "roi_size"):
+                self.sliding_window_size = np.prod(self.detector.inferer.roi_size)
+
+    def __call__(self, inputs: Union[List[Tensor], Tensor], network: torch.nn.Module, *args: Any, **kwargs: Any):
+        """Unified callable function API of inferers.
+        Args:
+            inputs: model input data for inference.
+            network: target detection network to execute inference; supports any callable
+                that fulfills the requirements of ``network`` in
+                ``monai.apps.detection.networks.retinanet_detector.RetinaNetDetector``.
+            args: optional args to be passed to ``network``.
+            kwargs: optional keyword args to be passed to ``network``.
+        """
+        self.detector.network = network
+        self.detector.training = self.detector.network.training
+
+        # if every image is smaller than the sliding-window roi size, there is no need for
+        # the sliding-window inferer; use it only when at least one image is large
+        use_inferer = self.sliding_window_size is not None and not all(
+            data_i[0, ...].numel() < self.sliding_window_size for data_i in inputs
+        )
+
+        return self.detector(inputs, use_inferer=use_inferer, *args, **kwargs)
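The new inferer wraps the detector's own inference logic behind the standard Inferer API: the engine hands it a network, and the detector decides per batch whether to run whole-image or sliding-window inference. A rough usage sketch follows; `detector` (a configured RetinaNetDetector) and `val_images` (a list of channel-first CT volumes) are assumed to exist, e.g. built by the bundle configs.

import torch
from scripts.detection_inferer import RetinaNetInferer

# mirrors the detector_ops entry in configs/inference.json
detector.set_sliding_window_inferer(
    roi_size=[512, 512, 192], overlap=0.25, sw_batch_size=1, mode="constant", device="cpu"
)
inferer = RetinaNetInferer(detector=detector)

detector.network.eval()
with torch.no_grad():
    # volumes smaller than the roi get a plain forward pass; larger ones use sliding windows
    outputs = inferer(inputs=val_images, network=detector.network)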
scripts/evaluator.py
CHANGED

@@ -13,19 +13,18 @@ from __future__ import annotations

 from typing import TYPE_CHECKING, Any, Callable, Dict, Iterable, List, Optional, Sequence, Tuple, Union

-import numpy as np
 import torch
+from monai.apps.detection.networks.retinanet_detector import RetinaNetDetector
 from monai.config import IgniteInfo
-from monai.engines.evaluator import Evaluator
+from monai.engines.evaluator import SupervisedEvaluator
 from monai.engines.utils import IterationEvents, default_metric_cmp_fn
-from monai.inferers import Inferer
-from monai.networks.utils import eval_mode, train_mode
 from monai.transforms import Transform
 from monai.utils import ForwardMode, min_version, optional_import
 from monai.utils.enums import CommonKeys as Keys
-from monai.utils.module import look_up_option
 from torch.utils.data import DataLoader

+from .detection_inferer import RetinaNetInferer
+
 if TYPE_CHECKING:
     from ignite.engine import Engine, EventEnum
     from ignite.metrics import Metric
@@ -67,17 +66,17 @@ def detection_prepare_val_batch(
     return inputs, None


-class DetectionEvaluator(Evaluator):
+class DetectionEvaluator(SupervisedEvaluator):
     """
-    Supervised detection evaluation method with image and label, inherits from ``Evaluator`` and ``Workflow``.
+    Supervised detection evaluation method with image and label, inherits from ``SupervisedEvaluator`` and ``Workflow``.
     Args:
         device: an object representing the device on which to run.
-        val_data_loader: Ignite engine use data_loader to run, must be Iterable …
-        detector: …
+        val_data_loader: Ignite engine use data_loader to run, must be Iterable, typically be torch.DataLoader.
+        network: detector to evaluate in the evaluator, should be regular PyTorch `torch.nn.Module`.
         epoch_length: number of iterations for one epoch, default to `len(val_data_loader)`.
         non_blocking: if True and this copy is between CPU and GPU, the copy may occur asynchronously
             with respect to the host. For other cases, this argument has no effect.
-        prepare_batch: function to parse expected data (usually `image` …
+        prepare_batch: function to parse expected data (usually `image`, `label` and other network args)
             from `engine.state.batch` for every iteration, for more details please refer to:
             https://pytorch.org/ignite/generated/ignite.engine.create_supervised_trainer.html.
         iteration_update: the callable function for every iteration, expect to accept `engine`
@@ -117,12 +116,12 @@ class DetectionEvaluator(Evaluator):
         self,
         device: torch.device,
         val_data_loader: Iterable | DataLoader,
-        detector,
+        network: RetinaNetDetector,
         epoch_length: int | None = None,
         non_blocking: bool = False,
         prepare_batch: Callable = detection_prepare_val_batch,
         iteration_update: Callable[[Engine, Any], Any] | None = None,
-        inferer: Inferer | None = None,
+        inferer: RetinaNetInferer | None = None,
         postprocessing: Transform | None = None,
         key_val_metric: dict[str, Metric] | None = None,
         additional_metrics: dict[str, Metric] | None = None,
@@ -139,10 +138,12 @@ class DetectionEvaluator(Evaluator):
         super().__init__(
             device=device,
             val_data_loader=val_data_loader,
+            network=network,
             epoch_length=epoch_length,
             non_blocking=non_blocking,
             prepare_batch=prepare_batch,
             iteration_update=iteration_update,
+            inferer=inferer,
             postprocessing=postprocessing,
             key_val_metric=key_val_metric,
             additional_metrics=additional_metrics,
@@ -157,16 +158,6 @@ class DetectionEvaluator(Evaluator):
             amp_kwargs=amp_kwargs,
         )

-        self.detector = detector
-
-        mode = look_up_option(mode, ForwardMode)
-        if mode == ForwardMode.EVAL:
-            self.mode = eval_mode
-        elif mode == ForwardMode.TRAIN:
-            self.mode = train_mode
-        else:
-            raise ValueError(f"unsupported mode: {mode}, should be 'eval' or 'train'.")
-
     def _register_decollate(self):
         """
         Register the decollate operation for batch data, will execute after model forward and loss forward.
@@ -181,48 +172,3 @@ class DetectionEvaluator(Evaluator):
                 if engine.state.output[k] is not None:
                     output_list[i][k] = engine.state.output[k][i]
             engine.state.output = output_list
-
-    def _iteration(self, engine, batchdata: dict[str, torch.Tensor]):
-        """
-        Callback function for the supervised evaluation processing logic of one iteration in an Ignite Engine.
-        Return below items in a dictionary:
-            - IMAGE: image Tensor data for model input, already moved to device.
-            - LABEL: label Tensor data corresponding to the image, already moved to device.
-            - PRED: prediction result of model.
-        Args:
-            engine: `SupervisedEvaluator` to execute operation for an iteration.
-            batchdata: input data for this iteration, usually can be dictionary or tuple of Tensor data.
-        Raises:
-            ValueError: When ``batchdata`` is None.
-        """
-
-        if batchdata is None:
-            raise ValueError("Must provide batch data for current iteration.")
-
-        batch = engine.prepare_batch(batchdata, engine.state.device, engine.non_blocking, **engine.to_kwargs)
-        if len(batch) == 2:
-            inputs, targets = batch
-            args: tuple = ()
-            kwargs: dict = {}
-        else:
-            inputs, targets, args, kwargs = batch
-        # put iteration outputs into engine.state
-        engine.state.output = {Keys.IMAGE: inputs, Keys.LABEL: targets}
-
-        # execute forward computation
-        sliding_window_size = np.prod(engine.detector.inferer.roi_size)
-
-        with engine.mode(engine.detector):
-
-            use_inferer = not all(val_data_i[0, ...].numel() < sliding_window_size for val_data_i in inputs)
-
-            if engine.amp:
-                with torch.cuda.amp.autocast(**engine.amp_kwargs):
-                    engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
-            else:
-                engine.state.output[Keys.PRED] = engine.detector(inputs, use_inferer=use_inferer)
-
-        engine.fire_event(IterationEvents.FORWARD_COMPLETED)
-        engine.fire_event(IterationEvents.MODEL_COMPLETED)
-
-        return engine.state.output
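With `_iteration`, the mode lookup, and the ad-hoc `detector` argument gone, DetectionEvaluator now delegates the forward pass to the standard network/inferer pair of SupervisedEvaluator. A hedged construction sketch follows; `detector`, `val_loader`, and `key_metric` are placeholders for the components the bundle configs actually build.

import torch
from scripts.detection_inferer import RetinaNetInferer
from scripts.evaluator import DetectionEvaluator

evaluator = DetectionEvaluator(
    device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"),
    val_data_loader=val_loader,
    network=detector,                             # the detector plays the role of "network"
    inferer=RetinaNetInferer(detector=detector),  # same wiring as the new "inferer" config entry
    key_val_metric=key_metric,
    amp=True,
)
evaluator.run()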
scripts/warmup_scheduler.py
CHANGED

@@ -39,6 +39,7 @@ class GradualWarmupScheduler(_LRScheduler):
         super(GradualWarmupScheduler, self).__init__(optimizer)

     def get_lr(self):
+        self.last_epoch = max(1, self.last_epoch)  # to avoid epoch=0 thus lr=0
         if self.last_epoch > self.total_epoch:
             if self.after_scheduler:
                 if not self.finished:
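The clamp matters because a linear warmup of the common form base_lr * last_epoch / total_epoch returns 0 when last_epoch is 0, so the first optimizer step would not learn at all. The formula below is the usual GradualWarmupScheduler rule, assumed rather than quoted from this script.

def warmup_lr(base_lr: float, last_epoch: int, total_epoch: int, clamp: bool) -> float:
    # illustrative linear-warmup rule, not a verbatim copy of get_lr
    if clamp:
        last_epoch = max(1, last_epoch)  # the one-line fix in this diff
    return base_lr * last_epoch / total_epoch

print(warmup_lr(1e-2, 0, 10, clamp=False))  # 0.0   -> lr stuck at zero on the first step
print(warmup_lr(1e-2, 0, 10, clamp=True))   # 0.001 -> warmup starts from a nonzero lr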