{
    "imports": [
        "$import functools",
        "$import glob",
        "$import scripts"
    ],
    "bundle_root": ".",
    "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
    "ckpt_path": "$@bundle_root + '/models/model.pt'",
    "dataset_dir": "./MedNIST/Hand",
    "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))",
    "latent_size": 64,
    "discriminator": {
        "_target_": "Discriminator",
        "in_shape": [
            1,
            64,
            64
        ],
        "channels": [
            8,
            16,
            32,
            64,
            1
        ],
        "strides": [
            2,
            2,
            2,
            2,
            1
        ],
        "num_res_units": 1,
        "kernel_size": 5
    },
    "generator": {
        "_target_": "Generator",
        "latent_shape": "@latent_size",
        "start_shape": [
            64,
            8,
            8
        ],
        "channels": [
            32,
            16,
            8,
            1
        ],
        "strides": [
            2,
            2,
            2,
            1
        ]
    },
    "dnetwork": "$@discriminator.apply(monai.networks.normal_init).to(@device)",
    "gnetwork": "$@generator.apply(monai.networks.normal_init).to(@device)",
    "preprocessing": {
        "_target_": "Compose",
        "transforms": [
            {
                "_target_": "LoadImaged",
                "keys": "reals"
            },
            {
                "_target_": "EnsureChannelFirstd",
                "keys": "reals"
            },
            {
                "_target_": "ScaleIntensityd",
                "keys": "reals"
            },
            {
                "_target_": "RandRotated",
                "keys": "reals",
                "range_x": "$np.pi/12",
                "prob": 0.5,
                "keep_size": true
            },
            {
                "_target_": "RandFlipd",
                "keys": "reals",
                "spatial_axis": 0,
                "prob": 0.5
            },
            {
                "_target_": "RandZoomd",
                "keys": "reals",
                "min_zoom": 0.9,
                "max_zoom": 1.1,
                "prob": 0.5
            },
            {
                "_target_": "EnsureTyped",
                "keys": "reals"
            }
        ]
    },
    "real_dataset": {
        "_target_": "CacheDataset",
        "data": "$[{'reals': i} for i in @datalist]",
        "transform": "@preprocessing"
    },
    "real_dataloader": {
        "_target_": "DataLoader",
        "dataset": "@real_dataset",
        "batch_size": 600,
        "shuffle": true,
        "num_workers": 12
    },
    "doptimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@dnetwork.parameters()",
        "lr": 0.0002,
        "betas": [
            0.5,
            0.999
        ]
    },
    "goptimizer": {
        "_target_": "torch.optim.Adam",
        "params": "$@gnetwork.parameters()",
        "lr": 0.0002,
        "betas": [
            0.5,
            0.999
        ]
    },
    "handlers": [
        {
            "_target_": "CheckpointSaver",
            "save_dir": "$@bundle_root + '/models'",
            "save_dict": {
                "model": "@gnetwork"
            },
            "save_interval": 0,
            "save_final": true,
            "epoch_level": true
        }
    ],
    "trainer": {
        "_target_": "GanTrainer",
        "device": "@device",
        "max_epochs": 50,
        "train_data_loader": "@real_dataloader",
        "g_network": "@gnetwork",
        "g_optimizer": "@goptimizer",
        "g_loss_function": "$functools.partial(scripts.losses.generator_loss, disc_net=@dnetwork)",
        "d_network": "@dnetwork",
        "d_optimizer": "@doptimizer",
        "d_loss_function": "$functools.partial(scripts.losses.discriminator_loss, disc_net=@dnetwork)",
        "d_train_steps": 5,
        "g_update_latents": true,
        "latent_shape": "@latent_size",
        "key_train_metric": "$None",
        "train_handlers": "@handlers"
    },
    "training": [
        "$@gnetwork.conv.add_module('activation', torch.nn.Sigmoid())",
        "$@trainer.run()"
    ]
}
|