from functools import partial import os import argparse import yaml import torch import torchvision.transforms as transforms import matplotlib.pyplot as plt from guided_diffusion.condition_methods import get_conditioning_method from guided_diffusion.measurements import get_noise, get_operator from guided_diffusion.unet import create_model from guided_diffusion.gaussian_diffusion import create_sampler from data.dataloader import get_dataset, get_dataloader from util.img_utils import clear_color, mask_generator from util.logger import get_logger def load_yaml(file_path: str) -> dict: with open(file_path) as f: config = yaml.load(f, Loader=yaml.FullLoader) return config def main(): parser = argparse.ArgumentParser() parser.add_argument('--model_config', type=str) parser.add_argument('--diffusion_config', type=str) parser.add_argument('--task_config', type=str) parser.add_argument('--gpu', type=int, default=0) parser.add_argument('--save_dir', type=str, default='./results') args = parser.parse_args() # logger logger = get_logger() # Device setting device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu' logger.info(f"Device set to {device_str}.") device = torch.device(device_str) # Load configurations model_config = load_yaml(args.model_config) diffusion_config = load_yaml(args.diffusion_config) task_config = load_yaml(args.task_config) #assert model_config['learn_sigma'] == diffusion_config['learn_sigma'], \ #"learn_sigma must be the same for model and diffusion configuartion." # Load model model = create_model(**model_config) model = model.to(device) model.eval() # Prepare Operator and noise measure_config = task_config['measurement'] operator = get_operator(device=device, **measure_config['operator']) noiser = get_noise(**measure_config['noise']) logger.info(f"Operation: {measure_config['operator']['name']} / Noise: {measure_config['noise']['name']}") # Prepare conditioning method cond_config = task_config['conditioning'] cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params']) measurement_cond_fn = cond_method.conditioning logger.info(f"Conditioning method : {task_config['conditioning']['method']}") # Load diffusion sampler sampler = create_sampler(**diffusion_config) sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=measurement_cond_fn) # Working directory out_path = os.path.join(args.save_dir, measure_config['operator']['name']) os.makedirs(out_path, exist_ok=True) for img_dir in ['input', 'recon', 'progress', 'label']: os.makedirs(os.path.join(out_path, img_dir), exist_ok=True) # Prepare dataloader data_config = task_config['data'] transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]) dataset = get_dataset(**data_config, transforms=transform) loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False) # Exception) In case of inpainting, we need to generate a mask if measure_config['operator']['name'] == 'inpainting': mask_gen = mask_generator( **measure_config['mask_opt'] ) # Do Inference for i, ref_img in enumerate(loader): logger.info(f"Inference for image {i}") fname = str(i).zfill(5) + '.png' ref_img = ref_img.to(device) # Exception) In case of inpainging, if measure_config['operator'] ['name'] == 'inpainting': mask = mask_gen(ref_img) mask = mask[:, 0, :, :].unsqueeze(dim=0) measurement_cond_fn = partial(cond_method.conditioning, mask=mask) sample_fn = partial(sample_fn, measurement_cond_fn=measurement_cond_fn) # Forward measurement model (Ax + n) y = operator.forward(ref_img, mask=mask) y_n = noiser(y) else: # Forward measurement model (Ax + n) y = operator.forward(ref_img) y_n = noiser(y) # Sampling x_start = torch.randn(ref_img.shape, device=device).requires_grad_() sample = sample_fn(x_start=x_start, measurement=y_n, record=True, save_root=out_path) plt.imsave(os.path.join(out_path, 'input', fname), clear_color(y_n)) plt.imsave(os.path.join(out_path, 'label', fname), clear_color(ref_img)) plt.imsave(os.path.join(out_path, 'recon', fname), clear_color(sample)) if __name__ == '__main__': main()