# Source: PSLD / diffusion-posterior-sampling / bkse / domain_specific_deblur.py
# Committed by LituRout in 2f1d50b ("add bkse"); 3.18 kB.
import argparse
from math import ceil, log10
from pathlib import Path
import torchvision
import yaml
from PIL import Image
from torch.nn import DataParallel
from torch.utils.data import DataLoader, Dataset
class Images(Dataset):
    """Dataset of the ``*.png`` files in a directory, each repeated ``duplicates`` times.

    Repeating an input makes the model produce multiple HR images for the
    same blurred input.

    Args:
        root_dir: Directory scanned (non-recursively) for ``*.png`` files.
        duplicates: Number of times each image appears in the dataset.

    Each item is an ``(image_tensor, name)`` pair. ``name`` is the file stem,
    suffixed with ``_1`` .. ``_<duplicates>`` when ``duplicates > 1``.
    """

    def __init__(self, root_dir, duplicates):
        self.root_path = Path(root_dir)
        # Sort so dataset indices, and therefore output names and processing
        # order, are deterministic. Path.glob yields files in arbitrary order.
        self.image_list = sorted(self.root_path.glob("*.png"))
        # Number of times to duplicate the image in the dataset to produce multiple HR images
        self.duplicates = duplicates

    def __len__(self):
        return self.duplicates * len(self.image_list)

    def __getitem__(self, idx):
        img_path = self.image_list[idx // self.duplicates]
        # The context manager closes the file handle as soon as the pixels have
        # been converted, rather than leaving it for garbage collection.
        with Image.open(img_path) as img:
            image = torchvision.transforms.ToTensor()(img)
        # NOTE(review): no colour-mode conversion is applied, so grayscale/RGBA
        # PNGs give tensors with a different channel count. Confirm that inputs
        # are always RGB.
        if self.duplicates == 1:
            return image, img_path.stem
        return image, f"{img_path.stem}_{idx % self.duplicates + 1}"
# Command-line interface. Every option is a string that names a path: where
# input images are read, where results and model-weight caches go, and which
# YAML configuration file to load.
parser = argparse.ArgumentParser(description="PULSE")
# I/O arguments, registered in order as (flag, default, help).
for _flag, _default, _help in (
    ("--input_dir", "imgs/blur_faces", "input data directory"),
    ("--output_dir", "experiments/domain_specific_deblur/results", "output data directory"),
    ("--cache_dir", "experiments/domain_specific_deblur/cache", "cache directory for model weights"),
    ("--yml_path", "options/domain_specific_deblur/stylegan2.yml", "configuration file"),
):
    parser.add_argument(_flag, type=str, default=_default, help=_help)
del _flag, _default, _help
# Parsed options as a plain dict, keyed by option name without the leading "--".
kwargs = vars(parser.parse_args())
# Run configuration. safe_load only builds plain Python objects from the YAML.
opt = yaml.safe_load(Path(kwargs["yml_path"]).read_bytes())

# Input data: each image is repeated opt["duplicates"] times. Batches are
# served in dataset order (no shuffling is requested).
dataset = Images(kwargs["input_dir"], duplicates=opt["duplicates"])
out_path = Path(kwargs["output_dir"])
out_path.mkdir(parents=True, exist_ok=True)
dataloader = DataLoader(dataset, batch_size=opt["batch_size"])

# Select the generator backbone. Version 1 uses StyleGAN; any other value
# falls back to StyleGAN2.
if opt["stylegan_ver"] == 1:
    from models.dsd.dsd_stylegan import DSDStyleGAN

    _backbone_cls = DSDStyleGAN
else:
    from models.dsd.dsd_stylegan2 import DSDStyleGAN2

    _backbone_cls = DSDStyleGAN2
model = DataParallel(_backbone_cls(opt=opt, cache_dir=kwargs["cache_dir"]))
toPIL = torchvision.transforms.ToPILImage()


def _save_image(image, path):
    """Detach one image tensor, move it to the CPU, clamp it to [0, 1] and save it to ``path``."""
    toPIL(image.detach().cpu().clamp(0, 1)).save(path)


# Zero-padding width of the step index j in intermediate file names.
# It is loop-invariant, so it is computed once.
# NOTE(review): 100 looks like an assumed number of steps; confirm against opt.
padding = ceil(log10(100))

for ref_im, ref_im_name in dataloader:
    # Size the loops by the batch actually received. The final batch is smaller
    # than opt["batch_size"] whenever len(dataset) is not a multiple of it
    # (DataLoader keeps partial batches by default). Indexing up to
    # opt["batch_size"] would then raise IndexError.
    n_images = len(ref_im_name)
    if opt["save_intermediate"]:
        # Keep one HR and one LR directory per image so each image's outputs go
        # to its own folders. Reusing a single loop variable would send the
        # whole batch into the last image's folders.
        hr_dirs = [out_path / name / "HR" for name in ref_im_name]
        lr_dirs = [out_path / name / "LR" for name in ref_im_name]
        for directory in hr_dirs + lr_dirs:
            directory.mkdir(parents=True, exist_ok=True)
        # Save every (HR, LR) pair the model yields, tagged with its index j.
        for j, (HR, LR) in enumerate(model(ref_im)):
            for i in range(n_images):
                file_name = f"{ref_im_name[i]}_{j:0{padding}}.png"
                _save_image(HR[i], hr_dirs[i] / file_name)
                _save_image(LR[i], lr_dirs[i] / file_name)
    else:
        # Each yielded HR estimate overwrites the same file per image. Once the
        # loop finishes, the file holds the last estimate; until then it holds
        # the latest one.
        for HR, _ in model(ref_im):
            for i in range(n_images):
                _save_image(HR[i], out_path / f"{ref_im_name[i]}.png")