{
  "imports": [
    "$import functools",
    "$import glob",
    "$import scripts"
  ],
  "bundle_root": ".",
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
  "ckpt_path": "$@bundle_root + '/models/model.pt'",
  "dataset_dir": "./MedNIST/Hand",
  "datalist": "$list(sorted(glob.glob(@dataset_dir + '/*.jpeg')))",
  "latent_size": 64,
  "discriminator": {
    "_target_": "Discriminator",
    "in_shape": [1, 64, 64],
    "channels": [8, 16, 32, 64, 1],
    "strides": [2, 2, 2, 2, 1],
    "num_res_units": 1,
    "kernel_size": 5
  },
  "generator": {
    "_target_": "Generator",
    "latent_shape": "@latent_size",
    "start_shape": [64, 8, 8],
    "channels": [32, 16, 8, 1],
    "strides": [2, 2, 2, 1]
  },
  "dnetwork": "$@discriminator.apply(monai.networks.normal_init).to(@device)",
  "gnetwork": "$@generator.apply(monai.networks.normal_init).to(@device)",
  "preprocessing": {
    "_target_": "Compose",
    "transforms": [
      {
        "_target_": "LoadImaged",
        "keys": "reals"
      },
      {
        "_target_": "EnsureChannelFirstd",
        "keys": "reals"
      },
      {
        "_target_": "ScaleIntensityd",
        "keys": "reals"
      },
      {
        "_target_": "RandRotated",
        "keys": "reals",
        "range_x": "$np.pi/12",
        "prob": 0.5,
        "keep_size": true
      },
      {
        "_target_": "RandFlipd",
        "keys": "reals",
        "spatial_axis": 0,
        "prob": 0.5
      },
      {
        "_target_": "RandZoomd",
        "keys": "reals",
        "min_zoom": 0.9,
        "max_zoom": 1.1,
        "prob": 0.5
      },
      {
        "_target_": "EnsureTyped",
        "keys": "reals"
      }
    ]
  },
  "real_dataset": {
    "_target_": "CacheDataset",
    "data": "$[{'reals': i} for i in @datalist]",
    "transform": "@preprocessing"
  },
  "real_dataloader": {
    "_target_": "DataLoader",
    "dataset": "@real_dataset",
    "batch_size": 600,
    "shuffle": true,
    "num_workers": 12
  },
  "doptimizer": {
    "_target_": "torch.optim.Adam",
    "params": "$@dnetwork.parameters()",
    "lr": 0.0002,
    "betas": [0.5, 0.999]
  },
  "goptimizer": {
    "_target_": "torch.optim.Adam",
    "params": "$@gnetwork.parameters()",
    "lr": 0.0002,
    "betas": [0.5, 0.999]
  },
  "handlers": [
    {
      "_target_": "CheckpointSaver",
      "save_dir": "$@bundle_root + '/models'",
      "save_dict": {
        "model": "@gnetwork"
      },
      "save_interval": 0,
      "save_final": true,
      "final_filename": "model.pt",
      "epoch_level": true
    }
  ],
  "trainer": {
    "_target_": "GanTrainer",
    "device": "@device",
    "max_epochs": 50,
    "train_data_loader": "@real_dataloader",
    "g_network": "@gnetwork",
    "g_optimizer": "@goptimizer",
    "g_loss_function": "$functools.partial(scripts.losses.generator_loss, disc_net=@dnetwork)",
    "d_network": "@dnetwork",
    "d_optimizer": "@doptimizer",
    "d_loss_function": "$functools.partial(scripts.losses.discriminator_loss, disc_net=@dnetwork)",
    "d_train_steps": 5,
    "g_update_latents": true,
    "latent_shape": "@latent_size",
    "key_train_metric": "$None",
    "train_handlers": "@handlers"
  },
  "training": [
    "$@gnetwork.conv.add_module('activation', torch.nn.Sigmoid())",
    "$@trainer.run()"
  ]
}