# Source: TextureScraping/swapae/evaluation/swap_generation_from_arranged_result_evaluator.py
# Author: sunshineatnoon
# Commit: 1b2a9b1 ("Add application file")
# File size: 4.9 kB
import glob
import torchvision.transforms as transforms
import os
import torch
from swapae.evaluation import BaseEvaluator
import swapae.util as util
from PIL import Image
class InputDataset(torch.utils.data.Dataset):
    """Dataset of position-aligned (structure, style) PNG image pairs.

    Images are discovered with the glob patterns
    ``<dataroot>/input_structure/*.png`` and ``<dataroot>/input_style/*.png``.
    Both path lists are sorted and then paired by index, so the two
    directories are expected to hold correspondingly named files.

    Each item is a dict with three keys:
        ``'structure'``: the transformed structure image,
        ``'style'``: the transformed style image,
        ``'path'``: the file path of the structure image.
    """

    def __init__(self, dataroot):
        """Collect and validate the image pairs found under ``dataroot``.

        Args:
            dataroot: directory containing the ``input_structure`` and
                ``input_style`` sub-directories.

        Raises:
            AssertionError: if the two directories hold a different number
                of PNG files, or if a pair of paths does not correspond.
        """
        structure_images, style_images = self._collect_pairs(dataroot)
        print("found %d images at %s" % (len(structure_images), dataroot))
        self.structure_images = structure_images
        self.style_images = style_images
        # ToTensor followed by per-channel Normalize with mean 0.5 / std 0.5.
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    @staticmethod
    def _collect_pairs(dataroot):
        """Return sorted ``(structure_paths, style_paths)`` after validation."""
        structure_images = sorted(
            glob.glob(os.path.join(dataroot, "input_structure", "*.png")))
        style_images = sorted(
            glob.glob(os.path.join(dataroot, "input_style", "*.png")))

        # Raised explicitly (rather than via ``assert``) so the checks are not
        # stripped under ``python -O``; the exception type is kept for callers.
        # The count is checked first because zip() below silently truncates.
        if len(structure_images) != len(style_images):
            raise AssertionError(
                "found %d structure images but %d style images at %s"
                % (len(structure_images), len(style_images), dataroot))

        for structure_path, style_path in zip(structure_images, style_images):
            # Primary rule: the two files share the same filename.
            same_name = (os.path.basename(structure_path)
                         == os.path.basename(style_path))
            # Legacy rule kept for backward compatibility. On its own it
            # wrongly rejects valid pairs whenever "structure" also appears
            # elsewhere in the path (e.g. in dataroot or in the filename),
            # because str.replace rewrites every occurrence.
            legacy_match = (structure_path.replace("structure", "style")
                            == style_path)
            if not (same_name or legacy_match):
                raise AssertionError(
                    "%s and %s do not match" % (structure_path, style_path))
        return structure_images, style_images

    def _load(self, path):
        """Open ``path``, convert it to RGB and apply ``self.transform``."""
        # The context manager closes the underlying file handle right away;
        # convert() returns an independent image, so it stays usable.
        with Image.open(path) as img:
            return self.transform(img.convert('RGB'))

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.structure_images)

    def __getitem__(self, idx):
        """Return the ``idx``-th pair as a dict (see the class docstring)."""
        return {'structure': self._load(self.structure_images[idx]),
                'style': self._load(self.style_images[idx]),
                'path': self.structure_images[idx]}
class SwapGenerationFromArrangedResultEvaluator(BaseEvaluator):
    """ Given two directories containing input structure and style (texture)
    images, respectively, generate reconstructed and swapped images.
    The input directories should contain the same set of image filenames.
    It differs from StructureStyleGridGenerationEvaluator, which creates
    N^2 outputs (i.e. swapping of all possible pairs between the structure and
    style images).

    For every (structure, style) pair, four images are added to an HTML page:
    the structure input, the style input, the reconstruction of the structure
    input, and the swap decoded from the structure image's first encoder
    output combined with the style image's second encoder output.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Return ``parser`` unchanged; this evaluator adds no options."""
        return parser

    def image_save_dir(self, nsteps):
        """Return ``<output_dir>/<target_phase>_<nsteps>/images``."""
        return os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps), "images")

    def create_webpage(self, nsteps):
        """Create ``self.webpage`` under ``<output_dir>/<target_phase>_<nsteps>``.

        Args:
            nsteps: iteration label. ``None`` falls back to
                ``self.opt.resume_iter``; an ``int`` is rounded to thousands
                and suffixed with "k" (e.g. 25000 -> "25k"); any other value
                is used as given.
        """
        if nsteps is None:
            nsteps = self.opt.resume_iter
        elif isinstance(nsteps, int):
            nsteps = str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "%s. iter=%s. phase=%s" % \
                        (self.opt.name, str(nsteps), self.target_phase)
        self.webpage = util.HTML(savedir, webpage_title)

    def add_to_webpage(self, images, filenames, tile=1):
        """Convert ``images`` to PIL images and add them to ``self.webpage``.

        Args:
            images: sequence whose items are tensors or lists of tensors. A
                list is stacked along a new leading dimension and the first
                two dimensions are then flattened into one.
            filenames: one output filename per entry of ``images``.
            tile: upper bound for the ``tile`` argument passed to
                ``util.tensor2im``; it is capped by the tensor's size along
                dimension 0.
        """
        converted_images = []
        for image in images:
            if isinstance(image, list):
                image = torch.stack(image, dim=0).flatten(0, 1)
            # presumably tensor2im returns an array Image.fromarray accepts
            # -- TODO confirm against swapae.util
            image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile)))
            converted_images.append(image)
        self.webpage.add_images(converted_images, filenames)
        print("saved %s" % str(filenames))
        # The page itself is written once, by evaluate(), not per image.

    def set_num_test_images(self, num_images):
        """Store ``num_images`` as the cap on processed image pairs."""
        self.num_test_images = num_images

    def evaluate(self, model, dataset, nsteps=None):
        """Generate reconstruction and swap results for every input pair.

        Args:
            model: callable supporting ``model(x, command="encode")``, which
                returns a pair of codes, and
                ``model(a, b, command="decode")``, which returns an image
                tensor.
            dataset: unused here; images are read from ``self.opt.dataroot``
                through ``InputDataset``.
            nsteps: iteration label forwarded to ``create_webpage``.

        Returns:
            An empty dict (this evaluator reports no metrics).
        """
        input_dataset = torch.utils.data.DataLoader(
            InputDataset(self.opt.dataroot),
            batch_size=1,
            shuffle=False, drop_last=False, num_workers=0
        )

        # NOTE(review): this reset discards any value stored earlier through
        # set_num_test_images(), so the cap below never triggers unless the
        # attribute changes while the loop runs. Kept as-is to preserve
        # behavior -- confirm whether the reset is intended.
        self.num_test_images = None
        self.create_webpage(nsteps)

        # Use the GPU when one is present (the original behavior); otherwise
        # keep tensors on the CPU instead of crashing in .cuda().
        use_cuda = torch.cuda.is_available()

        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            for image_num, data_i in enumerate(input_dataset, start=1):
                structure = data_i["structure"]
                style = data_i["style"]
                if use_cuda:
                    structure = structure.cuda()
                    style = style.cuda()
                # batch_size is 1, so the collated path list has one entry.
                path = os.path.basename(data_i["path"][0])

                sp, gl = model(structure, command="encode")
                rec = model(sp, gl, command="decode")
                # Only the second code of the style image is used for the swap.
                _, gl = model(style, command="encode")
                swapped = model(sp, gl, command="decode")

                self.add_to_webpage([structure, style, rec, swapped],
                                    ["%s_structure.png" % (path),
                                     "%s_style.png" % (path),
                                     "%s_rec.png" % (path),
                                     "%s_swap.png" % (path)],
                                    tile=1)

                if self.num_test_images is not None and self.num_test_images <= image_num:
                    break

        self.webpage.save()
        return {}