mednist_gan / configs /inference.json
katielink's picture
Update for 1.0
f8cfccc
raw
history blame
No virus
2.4 kB
{
  "imports": [
    "$import glob"
  ],
  "bundle_root": ".",
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
  "ckpt_path": "$@bundle_root + '/models/model.pt'",
  "output_dir": "./outputs",
  "latent_size": 64,
  "num_samples": 10,
  "network_def": {
    "_target_": "Generator",
    "latent_shape": "@latent_size",
    "start_shape": [64, 8, 8],
    "channels": [32, 16, 8, 1],
    "strides": [2, 2, 2, 1]
  },
  "network": "$@network_def.to(@device)",
  "dataset": {
    "_target_": "Dataset",
    "data": "$[torch.rand(@latent_size) for i in range(@num_samples)]"
  },
  "dataloader": {
    "_target_": "DataLoader",
    "dataset": "@dataset",
    "batch_size": 1,
    "shuffle": false,
    "num_workers": 0
  },
  "inferer": {
    "_target_": "SimpleInferer"
  },
  "postprocessing": {
    "_target_": "Compose",
    "transforms": [
      {
        "_target_": "Activationsd",
        "keys": "pred",
        "sigmoid": true
      },
      {
        "_target_": "ToTensord",
        "keys": "pred",
        "track_meta": false
      },
      {
        "_target_": "SaveImaged",
        "keys": "pred",
        "output_dir": "@output_dir",
        "output_ext": "png",
        "separate_folder": false,
        "scale": 255,
        "output_dtype": "$np.uint8",
        "meta_key_postfix": null
      }
    ]
  },
  "handlers": [
    {
      "_target_": "CheckpointLoader",
      "load_path": "@ckpt_path",
      "load_dict": {
        "model": "@network"
      }
    }
  ],
  "evaluator": {
    "_target_": "SupervisedEvaluator",
    "device": "@device",
    "val_data_loader": "@dataloader",
    "network": "@network",
    "inferer": "@inferer",
    "postprocessing": "@postprocessing",
    "prepare_batch": "$lambda batchdata, *_,**__: (batchdata.to(@device),None,(),{})",
    "val_handlers": "@handlers"
  },
  "inferring": [
    "$@evaluator.run()"
  ]
}