{
"imports": [
"$import functools",
"$import glob",
"$import scripts"
],
"bundle_root": ".",
"device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
"ckpt_path": "$@bundle_root + '/models/model.pt'",
"dataset_dir": "./MedNIST/Hand",
"datalist": "$sorted(glob.glob(@dataset_dir + '/*.jpeg'))",
"latent_size": 64,
"discriminator": {
"_target_": "Discriminator",
"in_shape": [
1,
64,
64
],
"channels": [
8,
16,
32,
64,
1
],
"strides": [
2,
2,
2,
2,
1
],
"num_res_units": 1,
"kernel_size": 5
},
"generator": {
"_target_": "Generator",
"latent_shape": "@latent_size",
"start_shape": [
64,
8,
8
],
"channels": [
32,
16,
8,
1
],
"strides": [
2,
2,
2,
1
]
},
"dnetwork": "$@discriminator.apply(monai.networks.normal_init).to(@device)",
"gnetwork": "$@generator.apply(monai.networks.normal_init).to(@device)",
"preprocessing": {
"_target_": "Compose",
"transforms": [
{
"_target_": "LoadImaged",
"keys": "reals"
},
{
"_target_": "EnsureChannelFirstd",
"keys": "reals"
},
{
"_target_": "ScaleIntensityd",
"keys": "reals"
},
{
"_target_": "RandRotated",
"keys": "reals",
"range_x": "$np.pi/12",
"prob": 0.5,
"keep_size": true
},
{
"_target_": "RandFlipd",
"keys": "reals",
"spatial_axis": 0,
"prob": 0.5
},
{
"_target_": "RandZoomd",
"keys": "reals",
"min_zoom": 0.9,
"max_zoom": 1.1,
"prob": 0.5
},
{
"_target_": "EnsureTyped",
"keys": "reals"
}
]
},
"real_dataset": {
"_target_": "CacheDataset",
"data": "$[{'reals': i} for i in @datalist]",
"transform": "@preprocessing"
},
"real_dataloader": {
"_target_": "DataLoader",
"dataset": "@real_dataset",
"batch_size": 600,
"shuffle": true,
"num_workers": 12
},
"doptimizer": {
"_target_": "torch.optim.Adam",
"params": "$@dnetwork.parameters()",
"lr": 0.0002,
"betas": [
0.5,
0.999
]
},
"goptimizer": {
"_target_": "torch.optim.Adam",
"params": "$@gnetwork.parameters()",
"lr": 0.0002,
"betas": [
0.5,
0.999
]
},
"handlers": [
{
"_target_": "CheckpointSaver",
"save_dir": "$@bundle_root + '/models'",
"save_dict": {
"model": "@gnetwork"
},
"save_interval": 0,
"save_final": true,
"epoch_level": true
}
],
"trainer": {
"_target_": "GanTrainer",
"device": "@device",
"max_epochs": 50,
"train_data_loader": "@real_dataloader",
"g_network": "@gnetwork",
"g_optimizer": "@goptimizer",
"g_loss_function": "$functools.partial(scripts.losses.generator_loss, disc_net=@dnetwork)",
"d_network": "@dnetwork",
"d_optimizer": "@doptimizer",
"d_loss_function": "$functools.partial(scripts.losses.discriminator_loss, disc_net=@dnetwork)",
"d_train_steps": 5,
"g_update_latents": true,
"latent_shape": "@latent_size",
"key_train_metric": "$None",
"train_handlers": "@handlers"
},
"training": [
"$@gnetwork.conv.add_module('activation', torch.nn.Sigmoid())",
"$@trainer.run()"
]
}