{
  "imports": [
    "$import torch",
    "$from datetime import datetime",
    "$from pathlib import Path",
    "$from PIL import Image",
    "$from scripts.utils import visualize_2d_image"
  ],
  "bundle_root": ".",
  "model_dir": "$@bundle_root + '/models'",
  "output_dir": "$@bundle_root + '/output'",
  "create_output_dir": "$Path(@output_dir).mkdir(parents=True, exist_ok=True)",
  "device": "$torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')",
  "output_postfix": "$datetime.now().strftime('sample_%Y%m%d_%H%M%S')",
  "channel": 0,
  "spatial_dims": 2,
  "image_channels": 1,
  "latent_channels": 1,
  "latent_shape": [
    "@latent_channels",
    64,
    64
  ],
  "autoencoder_def": {
    "_target_": "generative.networks.nets.AutoencoderKL",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@image_channels",
    "out_channels": "@image_channels",
    "latent_channels": "@latent_channels",
    "num_channels": [
      64,
      128,
      256
    ],
    "num_res_blocks": 2,
    "norm_num_groups": 32,
    "norm_eps": 1e-06,
    "attention_levels": [
      false,
      false,
      false
    ],
    "with_encoder_nonlocal_attn": true,
    "with_decoder_nonlocal_attn": true
  },
  "network_def": {
    "_target_": "generative.networks.nets.DiffusionModelUNet",
    "spatial_dims": "@spatial_dims",
    "in_channels": "@latent_channels",
    "out_channels": "@latent_channels",
    "num_channels": [
      32,
      64,
      128,
      256
    ],
    "attention_levels": [
      false,
      true,
      true,
      true
    ],
    "num_head_channels": [
      0,
      32,
      32,
      32
    ],
    "num_res_blocks": 2
  },
  "load_autoencoder_path": "$@model_dir + '/model_autoencoder.pt'",
  "load_autoencoder": "$@autoencoder_def.load_state_dict(torch.load(@load_autoencoder_path, map_location=@device))",
  "autoencoder": "$@autoencoder_def.to(@device)",
  "load_diffusion_path": "$@model_dir + '/model.pt'",
  "load_diffusion": "$@network_def.load_state_dict(torch.load(@load_diffusion_path, map_location=@device))",
  "diffusion": "$@network_def.to(@device)",
  "noise_scheduler": {
    "_target_": "generative.networks.schedulers.DDIMScheduler",
    "_requires_": [
      "@load_diffusion",
      "@load_autoencoder"
    ],
    "num_train_timesteps": 1000,
    "beta_start": 0.0015,
    "beta_end": 0.0195,
    "schedule": "scaled_linear_beta",
    "clip_sample": false
  },
  "noise": "$torch.randn([1]+@latent_shape).to(@device)",
  "set_timesteps": "$@noise_scheduler.set_timesteps(num_inference_steps=50)",
  "inferer": {
    "_target_": "scripts.ldm_sampler.LDMSampler",
    "_requires_": "@set_timesteps"
  },
  "sample": "$@inferer.sampling_fn(@noise, @autoencoder, @diffusion, @noise_scheduler)",
  "generated_image": "$@sample",
  "generated_image_np": "$@generated_image[0,0].cpu().numpy().transpose(1, 0)[::-1, ::-1]",
  "img_pil": "$Image.fromarray(visualize_2d_image(@generated_image_np), 'RGB')",
  "run": [
    "$@create_output_dir",
    "$@img_pil.save(@output_dir+'/synimg_'+@output_postfix+'.png')"
  ]
}