{
    "ckpt_dir": "$@bundle_root + '/models'",
    "train_batch_size_img": 2,
    "train_batch_size_slice": 50,
    "lr": 5e-05,
    "train_patch_size": [
        256,
        256
    ],
    "latent_shape": [
        "@latent_channels",
        64,
        64
    ],
    "load_autoencoder_path": "$@bundle_root + '/models/model_autoencoder.pt'",
    "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path))",
    "autoencoder": "$@autoencoder_def.to(@device)",
    "network_def": {
        "_target_": "generative.networks.nets.DiffusionModelUNet",
        "spatial_dims": "@spatial_dims",
        "in_channels": "@latent_channels",
        "out_channels": "@latent_channels",
        "num_channels": [
            32,
            64,
            128,
            256
        ],
        "attention_levels": [
            false,
            true,
            true,
            true
        ],
        "num_head_channels": [
            0,
            32,
            32,
            32
        ],
        "num_res_blocks": 2
    },
    "diffusion": "$@network_def.to(@device)",
    "optimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@diffusion.parameters()",
        "lr": "@lr"
    },
    "lr_scheduler": {
        "_target_": "torch.optim.lr_scheduler.MultiStepLR",
        "optimizer": "@optimizer",
        "milestones": [
            1000
        ],
        "gamma": 0.1
    },
    "scale_factor": "$scripts.utils.compute_scale_factor(@autoencoder,@train#dataloader,@device)",
    "noise_scheduler": {
        "_target_": "generative.networks.schedulers.DDPMScheduler",
        "_requires_": [
            "@load_autoencoder"
        ],
        "schedule": "scaled_linear_beta",
        "num_train_timesteps": 1000,
        "beta_start": 0.0015,
        "beta_end": 0.0195
    },
    "inferer": {
        "_target_": "generative.inferers.LatentDiffusionInferer",
        "scheduler": "@noise_scheduler",
        "scale_factor": "@scale_factor"
    },
    "loss": {
        "_target_": "torch.nn.MSELoss"
    },
    "train": {
        "crop_transforms": [
            {
                "_target_": "DivisiblePadd",
                "keys": "image",
                "k": [
                    32,
                    32,
                    1
                ]
            },
            {
                "_target_": "RandSpatialCropSamplesd",
                "keys": "image",
                "random_size": false,
                "roi_size": "$[@train_patch_size[0], @train_patch_size[1], 1]",
                "num_samples": "@train_batch_size_slice"
            },
            {
                "_target_": "SqueezeDimd",
                "keys": "image",
                "dim": 3
            }
        ],
        "preprocessing": {
            "_target_": "Compose",
            "transforms": "$@preprocessing_transforms + @train#crop_transforms"
        },
        "dataset": {
            "_target_": "monai.apps.DecathlonDataset",
            "root_dir": "@dataset_dir",
            "task": "Task01_BrainTumour",
            "section": "training",
            "cache_rate": 1.0,
            "num_workers": 8,
            "download": false,
            "transform": "@train#preprocessing"
        },
        "dataloader": {
            "_target_": "DataLoader",
            "dataset": "@train#dataset",
            "batch_size": "@train_batch_size_img",
            "shuffle": true,
            "num_workers": 0
        },
        "handlers": [
            {
                "_target_": "LrScheduleHandler",
                "lr_scheduler": "@lr_scheduler",
                "print_lr": true
            },
            {
                "_target_": "CheckpointSaver",
                "save_dir": "@ckpt_dir",
                "save_dict": {
                    "model": "@diffusion"
                },
                "save_interval": 0,
                "save_final": true,
                "epoch_level": true,
                "final_filename": "model.pt"
            },
            {
                "_target_": "StatsHandler",
                "tag_name": "train_diffusion_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
            },
            {
                "_target_": "TensorBoardStatsHandler",
                "log_dir": "@tf_dir",
                "tag_name": "train_diffusion_loss",
                "output_transform": "$lambda x: monai.handlers.from_engine(['loss'], first=True)(x)"
            }
        ],
        "trainer": {
            "_target_": "scripts.ldm_trainer.LDMTrainer",
            "device": "@device",
            "max_epochs": 1000,
            "train_data_loader": "@train#dataloader",
            "network": "@diffusion",
            "autoencoder_model": "@autoencoder",
            "optimizer": "@optimizer",
            "loss_function": "@loss",
            "latent_shape": "@latent_shape",
            "inferer": "@inferer",
            "key_train_metric": "$None",
            "train_handlers": "@train#handlers"
        }
    },
    "initialize": [
        "$monai.utils.set_determinism(seed=0)"
    ],
    "run": [
        "@load_autoencoder",
        "$@autoencoder.eval()",
        "$print('scale factor:',@scale_factor)",
        "$@train#trainer.run()"
    ]
}
|