Spaces:
Sleeping
Sleeping
File size: 4,712 Bytes
c735a8f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
from functools import partial
import os
import argparse
import yaml
import torch
import torchvision.transforms as transforms
import matplotlib.pyplot as plt
from guided_diffusion.condition_methods import get_conditioning_method
from guided_diffusion.measurements import get_noise, get_operator
from guided_diffusion.unet import create_model
from guided_diffusion.gaussian_diffusion import create_sampler
from data.dataloader import get_dataset, get_dataloader
from util.img_utils import clear_color, mask_generator
from util.logger import get_logger
def load_yaml(file_path: str) -> dict:
    """Load a YAML configuration file and return the parsed document.

    Args:
        file_path: Path to the YAML file.

    Returns:
        The parsed top-level YAML document.

    Raises:
        ValueError: If the file contains no YAML document (e.g. it is empty).
            The callers in this script unpack the result with ``**``, which
            would otherwise fail later with a confusing TypeError on ``None``.
    """
    # Read as UTF-8 explicitly; the platform default encoding depends on the locale.
    with open(file_path, encoding='utf-8') as f:
        # NOTE(security): FullLoader should only be used on trusted config files;
        # consider yaml.safe_load if configs can come from untrusted sources.
        config = yaml.load(f, Loader=yaml.FullLoader)
    if config is None:
        raise ValueError(f"YAML file {file_path!r} is empty or contains no document.")
    return config
def main():
    """Run measurement-conditioned diffusion sampling over a dataset.

    Command-line arguments:
        --model_config: YAML file whose contents are passed to ``create_model``.
        --diffusion_config: YAML file whose contents are passed to ``create_sampler``.
        --task_config: YAML file; its ``measurement``, ``conditioning`` and
            ``data`` sections are read below.
        --gpu: CUDA device index (only used when CUDA is available).
        --save_dir: root directory for outputs (default ``./results``).

    For each image from the dataloader, a noisy measurement ``y_n`` is built by
    applying the configured operator and then the noise model. The sampler is
    run from a random Gaussian start. The measurement, reference image and
    reconstruction are saved as PNGs under
    ``<save_dir>/<operator name>/{input,label,recon}/``.
    """
    parser = argparse.ArgumentParser()
    # The config paths are mandatory: without them load_yaml() would receive
    # None and fail with an obscure TypeError from open().
    parser.add_argument('--model_config', type=str, required=True)
    parser.add_argument('--diffusion_config', type=str, required=True)
    parser.add_argument('--task_config', type=str, required=True)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--save_dir', type=str, default='./results')
    args = parser.parse_args()

    # logger
    logger = get_logger()

    # Device setting
    device_str = f"cuda:{args.gpu}" if torch.cuda.is_available() else 'cpu'
    logger.info(f"Device set to {device_str}.")
    device = torch.device(device_str)

    # Load configurations
    model_config = load_yaml(args.model_config)
    diffusion_config = load_yaml(args.diffusion_config)
    task_config = load_yaml(args.task_config)
    # NOTE: a check that model_config and diffusion_config agree on
    # 'learn_sigma' was previously disabled here; it is not enforced.

    # Load model
    model = create_model(**model_config)
    model = model.to(device)
    model.eval()

    # Prepare operator and noise
    measure_config = task_config['measurement']
    operator_name = measure_config['operator']['name']
    operator = get_operator(device=device, **measure_config['operator'])
    noiser = get_noise(**measure_config['noise'])
    logger.info(f"Operation: {operator_name} / Noise: {measure_config['noise']['name']}")

    # Prepare conditioning method
    cond_config = task_config['conditioning']
    cond_method = get_conditioning_method(cond_config['method'], operator, noiser, **cond_config['params'])
    measurement_cond_fn = cond_method.conditioning
    logger.info(f"Conditioning method : {cond_config['method']}")

    # Load diffusion sampler
    sampler = create_sampler(**diffusion_config)
    sample_fn = partial(sampler.p_sample_loop, model=model, measurement_cond_fn=measurement_cond_fn)

    # Working directory: one sub-folder per output kind
    out_path = os.path.join(args.save_dir, operator_name)
    os.makedirs(out_path, exist_ok=True)
    for img_dir in ['input', 'recon', 'progress', 'label']:
        os.makedirs(os.path.join(out_path, img_dir), exist_ok=True)

    # Prepare dataloader
    data_config = task_config['data']
    transform = transforms.Compose([transforms.ToTensor(),
                                    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
    dataset = get_dataset(**data_config, transforms=transform)
    loader = get_dataloader(dataset, batch_size=1, num_workers=0, train=False)

    # Special case: inpainting needs a mask generated for every image.
    is_inpainting = operator_name == 'inpainting'
    if is_inpainting:
        mask_gen = mask_generator(**measure_config['mask_opt'])

    # Do inference
    for i, ref_img in enumerate(loader):
        logger.info(f"Inference for image {i}")
        fname = str(i).zfill(5) + '.png'
        ref_img = ref_img.to(device)

        if is_inpainting:
            mask = mask_gen(ref_img)
            # Keep channel 0 and add a leading dim; with batch_size=1 this is
            # presumably (1, 1, H, W) -- TODO confirm against mask_generator.
            mask = mask[:, 0, :, :].unsqueeze(dim=0)
            # Bind this image's mask for this iteration only, rather than
            # rebinding the shared sample_fn across iterations.
            image_sample_fn = partial(sample_fn,
                                      measurement_cond_fn=partial(cond_method.conditioning, mask=mask))
            # Forward measurement model (Ax + n)
            y = operator.forward(ref_img, mask=mask)
        else:
            image_sample_fn = sample_fn
            # Forward measurement model (Ax + n)
            y = operator.forward(ref_img)
        y_n = noiser(y)

        # Sampling: Gaussian start that tracks gradients (presumably required
        # by the conditioning step -- confirm in the sampler/conditioning code).
        x_start = torch.randn(ref_img.shape, device=device).requires_grad_()
        sample = image_sample_fn(x_start=x_start, measurement=y_n, record=True, save_root=out_path)

        plt.imsave(os.path.join(out_path, 'input', fname), clear_color(y_n))
        plt.imsave(os.path.join(out_path, 'label', fname), clear_color(ref_img))
        plt.imsave(os.path.join(out_path, 'recon', fname), clear_color(sample))
if __name__ == '__main__':
    # Run the sampling pipeline only when executed as a script, not on import.
    main()
|