Spaces:
Runtime error
Runtime error
import glob | |
import torchvision.transforms as transforms | |
import os | |
import torch | |
from swapae.evaluation import BaseEvaluator | |
import swapae.util as util | |
from PIL import Image | |
class InputDataset(torch.utils.data.Dataset):
    """Paired structure/style images read from ``<dataroot>/input_structure``
    and ``<dataroot>/input_style``.

    Both directories are globbed for ``*.png`` files, sorted, and paired by
    position. Item ``idx`` is a dict with the transformed structure image
    (``'structure'``), the transformed style image (``'style'``) and the
    structure image's path (``'path'``).

    Raises:
        ValueError: if the directories hold a different number of PNGs, or a
            positional pair's filenames do not match (see ``_check_pairs``).
    """

    def __init__(self, dataroot):
        structure_images = sorted(glob.glob(os.path.join(dataroot, "input_structure", "*.png")))
        style_images = sorted(glob.glob(os.path.join(dataroot, "input_style", "*.png")))
        self._check_pairs(structure_images, style_images)
        print("found %d images at %s" % (len(structure_images), dataroot))
        self.structure_images = structure_images
        self.style_images = style_images
        # ToTensor maps RGB pixel values to [0, 1]; Normalize with
        # mean = std = 0.5 then maps every channel to [-1, 1].
        self.transform = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
            ]
        )

    @staticmethod
    def _check_pairs(structure_images, style_images):
        """Validate that two sorted path lists pair up one-to-one by filename.

        Only basenames are compared, after replacing every "structure" in the
        structure filename with "style" (so ``a.png``/``a.png`` and
        ``a_structure.png``/``a_style.png`` both match). Comparing basenames
        instead of full paths avoids a spurious mismatch when the data root
        itself contains the substring "structure".

        Raises:
            ValueError: on a count mismatch or a filename mismatch.
        """
        # Check counts first: zip() would silently truncate the longer list.
        if len(structure_images) != len(style_images):
            raise ValueError(
                "found %d structure images but %d style images"
                % (len(structure_images), len(style_images))
            )
        for structure_path, style_path in zip(structure_images, style_images):
            expected = os.path.basename(structure_path).replace("structure", "style")
            if expected != os.path.basename(style_path):
                raise ValueError("%s and %s do not match" % (structure_path, style_path))

    def __len__(self):
        return len(self.structure_images)

    def _load(self, path):
        """Open ``path``, convert it to RGB and apply ``self.transform``.

        The ``with`` block guarantees the file handle is closed on return.
        """
        with Image.open(path) as img:
            return self.transform(img.convert("RGB"))

    def __getitem__(self, idx):
        return {
            "structure": self._load(self.structure_images[idx]),
            "style": self._load(self.style_images[idx]),
            "path": self.structure_images[idx],
        }
class SwapGenerationFromArrangedResultEvaluator(BaseEvaluator):
    """Generate reconstructed and swapped images from pre-arranged pairs.

    Structure and style (texture) images are read from
    ``<opt.dataroot>/input_structure`` and ``<opt.dataroot>/input_style``.
    The two directories should contain the same set of image filenames; the
    images are paired one-to-one by sorted filename.

    This differs from StructureStyleGridGenerationEvaluator, which creates
    N^2 outputs (i.e. swapping of all possible pairs between the structure
    and style images). Here each pair yields one reconstruction and one swap.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add no evaluator-specific options; return ``parser`` unchanged."""
        return parser

    def image_save_dir(self, nsteps):
        """Return ``<output_dir>/<target_phase>_<nsteps>/images``."""
        return os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps), "images")

    def create_webpage(self, nsteps):
        """Create the output directory and an HTML page, stored as ``self.webpage``.

        Args:
            nsteps: None selects ``self.opt.resume_iter``; an int is shown in
                thousands with a "k" suffix (e.g. 12345 -> "12k"); any other
                value is used as-is in the directory name and page title.
        """
        if nsteps is None:
            nsteps = self.opt.resume_iter
        elif isinstance(nsteps, int):
            nsteps = str(round(nsteps / 1000)) + "k"
        savedir = os.path.join(self.output_dir(), "%s_%s" % (self.target_phase, nsteps))
        os.makedirs(savedir, exist_ok=True)
        webpage_title = "%s. iter=%s. phase=%s" % \
            (self.opt.name, str(nsteps), self.target_phase)
        self.webpage = util.HTML(savedir, webpage_title)

    def add_to_webpage(self, images, filenames, tile=1):
        """Convert tensors to PIL images and append them to ``self.webpage``.

        Args:
            images: tensors, or lists of tensors; a list is stacked along a new
                leading dim and its first two dims are then merged.
            filenames: one name per entry of ``images``, passed to
                ``self.webpage.add_images``.
            tile: forwarded to ``util.tensor2im``, capped at ``image.size(0)``.
        """
        converted_images = []
        for image in images:
            if isinstance(image, list):
                image = torch.stack(image, dim=0).flatten(0, 1)
            image = Image.fromarray(util.tensor2im(image, tile=min(image.size(0), tile)))
            converted_images.append(image)
        self.webpage.add_images(converted_images, filenames)
        print("saved %s" % str(filenames))
        # The page itself is written once, by evaluate(), after all images are added.

    def set_num_test_images(self, num_images):
        """Limit how many pairs evaluate() processes (None means no limit)."""
        self.num_test_images = num_images

    def evaluate(self, model, dataset, nsteps=None):
        """Reconstruct and swap every structure/style pair under ``opt.dataroot``.

        For each pair, the structure image is encoded into ``(sp, gl)``.
        Decoding ``(sp, gl)`` gives the reconstruction; decoding ``sp`` with
        the style image's ``gl`` gives the swap. Structure, style,
        reconstruction and swap are added to the webpage, which is saved once
        at the end.

        Args:
            model: called as ``model(x, command="encode") -> (sp, gl)`` and
                ``model(sp, gl, command="decode")``.
            dataset: unused; pairs come from ``InputDataset(self.opt.dataroot)``.
            nsteps: forwarded to ``create_webpage`` to name the output.

        Returns:
            An empty dict (this evaluator reports no metrics).
        """
        input_dataset = torch.utils.data.DataLoader(
            InputDataset(self.opt.dataroot),
            batch_size=1,
            shuffle=False, drop_last=False, num_workers=0,
        )
        # Honour a limit set through set_num_test_images(); None (also the
        # value when it was never called) means every pair is processed.
        max_images = getattr(self, "num_test_images", None)
        self.create_webpage(nsteps)
        # Inference only: no autograd graph is needed for these outputs.
        with torch.no_grad():
            for image_num, data_i in enumerate(input_dataset, start=1):
                structure = data_i["structure"].cuda()
                style = data_i["style"].cuda()
                # batch_size=1, so the collated path list has exactly one entry.
                path = os.path.basename(data_i["path"][0])
                sp, gl = model(structure, command="encode")
                rec = model(sp, gl, command="decode")
                _, style_gl = model(style, command="encode")
                swapped = model(sp, style_gl, command="decode")
                self.add_to_webpage([structure, style, rec, swapped],
                                    ["%s_structure.png" % (path),
                                     "%s_style.png" % (path),
                                     "%s_rec.png" % (path),
                                     "%s_swap.png" % (path)],
                                    tile=1)
                if max_images is not None and image_num >= max_images:
                    break
        self.webpage.save()
        return {}