---
# CUDA device for this process, indexed by its torch.distributed rank
# (`dist` is imported by the first entry of the `initialize` list).
# NOTE(review): the index is the global rank, which matches a local GPU
# index only when every rank runs on one node - confirm for multi-node use.
device: "$torch.device(f'cuda:{dist.get_rank()}')"
# Wrap the model in DistributedDataParallel. `@network_def` is defined
# outside this file; it is moved to `@device` before being wrapped, and the
# same device is passed as the single entry of `device_ids`.
network:
  _target_: torch.nn.parallel.DistributedDataParallel
  module: "$@network_def.to(@device)"
  find_unused_parameters: true
  device_ids:
    - "@device"
# Learning rate: 0.025 multiplied by the number of processes in the group.
optimizer#lr: "$0.025*dist.get_world_size()"
# Scheduler step size: 80 multiplied by the number of processes in the group.
lr_scheduler#step_size: "$80*dist.get_world_size()"
# Training handlers. Order matters: `train#trainer#train_handlers` slices off
# the LAST TWO entries on ranks other than 0, so the two logging handlers
# must stay at the end of this list.
train#handlers:
  # Steps `@lr_scheduler` (defined outside this file) and prints the rate.
  - _target_: LrScheduleHandler
    lr_scheduler: "@lr_scheduler"
    print_lr: true
  # Runs `@validate#evaluator`; the interval is 10 multiplied by the number
  # of processes, counted in epochs (epoch_level: true).
  - _target_: ValidationHandler
    validator: "@validate#evaluator"
    epoch_level: true
    interval: "$10*dist.get_world_size()"
  # Logging handler fed with the first 'loss' value of the engine output.
  - _target_: StatsHandler
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
  # Same loss value, written under `@output_dir` (defined outside this file).
  - _target_: TensorBoardStatsHandler
    log_dir: "@output_dir"
    tag_name: train_loss
    output_transform: "$monai.handlers.from_engine(['loss'], first=True)"
# Epoch budget: 400 multiplied by the number of processes in the group.
train#trainer#max_epochs: "$400*dist.get_world_size()"
# Rank 0 keeps the full `train#handlers` list (slice end is None); every
# other rank drops the last two entries (slice end is -2), which in this
# file are the StatsHandler and TensorBoardStatsHandler.
train#trainer#train_handlers: "$@train#handlers[: -2 if dist.get_rank() > 0 else None]"
# Only rank 0 gets `@validate#handlers` (defined outside this file); every
# other rank gets None.
validate#evaluator#val_handlers: "$None if dist.get_rank() > 0 else @validate#handlers"
# Setup expressions, listed in the order they are meant to run.
initialize:
  # Makes `dist` available to the other expressions in this file.
  - "$import torch.distributed as dist"
  # `or` short-circuits: the NCCL process group is created only when one is
  # not already initialized.
  - "$dist.is_initialized() or dist.init_process_group(backend='nccl')"
  - "$torch.cuda.set_device(@device)"
  - "$monai.utils.set_determinism(seed=123)"
  # NOTE(review): cudnn.benchmark is switched on right after the seed is
  # fixed; presumably a deliberate speed-over-reproducibility choice - confirm.
  - "$setattr(torch.backends.cudnn, 'benchmark', True)"
# Main step: call run() on the trainer defined under `train#trainer`.
run:
  - "$@train#trainer.run()"
# Teardown: `and` short-circuits, so the process group is destroyed only
# when one is initialized.
finalize:
  - "$dist.is_initialized() and dist.destroy_process_group()"
# This rank's shard of `@train_datalist` (defined outside this file): the
# list is split into one partition per process with shuffle=True and
# even_divisible=True, then indexed by rank.
# Folded scalar (>-): the lines below join with single spaces into one
# Python expression, with no trailing newline.
train_data_partition: >-
  $monai.data.partition_dataset(data=@train_datalist,
  num_partitions=dist.get_world_size(),
  shuffle=True, even_divisible=True,)[dist.get_rank()]
# Training dataset built from this rank's partition only. The transform
# chain `@train#preprocessing` is defined outside this file.
train#dataset:
  _target_: CacheDataset
  data: "@train_data_partition"
  transform: "@train#preprocessing"
  cache_rate: 1
  num_workers: 4
# This rank's shard of `@val_datalist` (defined outside this file): split
# into one partition per process with shuffle=False and
# even_divisible=False, then indexed by rank.
# Folded scalar (>-): the lines below join with single spaces into one
# Python expression, with no trailing newline.
val_data_partition: >-
  $monai.data.partition_dataset(data=@val_datalist,
  num_partitions=dist.get_world_size(),
  shuffle=False, even_divisible=False,)[dist.get_rank()]
# Validation dataset built from this rank's partition only. The transform
# chain `@validate#preprocessing` is defined outside this file.
validate#dataset:
  _target_: CacheDataset
  data: "@val_data_partition"
  transform: "@validate#preprocessing"
  cache_rate: 1
  num_workers: 4
|