Spaces:
Runtime error
Runtime error
import os | |
import torchvision.transforms as transforms | |
from PIL import Image | |
from swapae.evaluation import BaseEvaluator | |
from swapae.data.base_dataset import get_transform | |
import swapae.util as util | |
class SimpleSwappingEvaluator(BaseEvaluator):
    """Swap the texture of one image onto the structure of another.

    Instead of iterating a dataset, this evaluator reads exactly two images
    named on the command line (``--input_structure_image`` and
    ``--input_texture_image``). It encodes both and interpolates their
    texture codes for every value in ``--texture_mix_alphas``. Each result
    is decoded and saved as a PNG under ``self.output_dir()``.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this evaluator's options on ``parser`` and return it.

        It also parses the arguments given so far (``parse_known_args``) so
        that ``dataroot`` defaults to the directory of the structure image
        and ``dataset_mode`` defaults to ``"imagefolder"``.

        Args:
            parser: argparse-style parser; ``add_argument``,
                ``parse_known_args`` and ``set_defaults`` are called on it.
            is_train: not used by this evaluator.

        Returns:
            The same ``parser`` object.
        """
        parser.add_argument("--input_structure_image", required=True, type=str)
        parser.add_argument("--input_texture_image", required=True, type=str)
        parser.add_argument(
            "--texture_mix_alphas", type=float, nargs='+',
            default=[1.0],
            # Adjacent string literals are concatenated verbatim, so each
            # sentence needs its own trailing space.
            help="Performs interpolation of the texture image. "
                 "If set to 1.0, it performs full swapping. "
                 "If set to 0.0, it performs direct reconstruction.",
        )

        opt, _ = parser.parse_known_args()
        # Expand "~" the same way load_image() does, so dataroot names the
        # real directory rather than a literal "~".
        dataroot = os.path.dirname(
            os.path.expanduser(opt.input_structure_image))

        # dataroot and dataset_mode are ignored in SimpleSwappingEvaluator.
        # Just set it to the directory that contains the input structure image.
        parser.set_defaults(dataroot=dataroot, dataset_mode="imagefolder")
        return parser

    def load_image(self, path):
        """Load the image at ``path`` and return it with a batch dimension.

        ``~`` in ``path`` is expanded. The image is converted to RGB and
        passed through ``get_transform(self.opt)``, then ``unsqueeze(0)``
        prepends a batch axis of size 1. This assumes the transform returns
        a torch tensor; TODO confirm against ``get_transform``.
        """
        path = os.path.expanduser(path)
        # Image.open only reads the header and keeps the file open.
        # convert() loads the pixels into a new image, so the handle can be
        # closed immediately instead of leaking until garbage collection.
        with Image.open(path) as raw_img:
            img = raw_img.convert('RGB')
        transform = get_transform(self.opt)
        return transform(img).unsqueeze(0)

    def evaluate(self, model, dataset, nsteps=None):
        """Swap textures between the two command-line images and save them.

        ``dataset`` and ``nsteps`` are not used. The inputs come from
        ``--input_structure_image`` and ``--input_texture_image``.

        For each ``alpha`` in ``--texture_mix_alphas``, the texture code is
        ``util.lerp(source_texture_code, target_texture_code, alpha)``. Per
        the option help, 0.0 reconstructs and 1.0 fully swaps. The decoded
        image is saved as ``<structure>_<texture>_<alpha:.2f>.png``.

        Returns:
            An empty dict (no metrics are computed).
        """
        structure_image = self.load_image(self.opt.input_structure_image)
        texture_image = self.load_image(self.opt.input_texture_image)
        output_dir = self.output_dir()
        os.makedirs(output_dir, exist_ok=True)

        # Issued before encoding; exact semantics are defined by the model's
        # "fix_noise" command (presumably freezes its noise inputs -- confirm).
        model(sample_image=structure_image, command="fix_noise")
        structure_code, source_texture_code = model(
            structure_image, command="encode")
        _, target_texture_code = model(texture_image, command="encode")

        # The output file-name stems and the PIL converter do not depend on
        # alpha, so build them once outside the loop.
        structure_name = os.path.splitext(
            os.path.basename(self.opt.input_structure_image))[0]
        texture_name = os.path.splitext(
            os.path.basename(self.opt.input_texture_image))[0]
        to_pil = transforms.ToPILImage()

        for alpha in self.opt.texture_mix_alphas:
            texture_code = util.lerp(
                source_texture_code, target_texture_code, alpha)
            output_image = model(structure_code, texture_code, command="decode")
            # Clamp the first sample to [-1, 1] and rescale to [0, 1] before
            # converting to a PIL image.
            output_image = to_pil(
                (output_image[0].clamp(-1.0, 1.0) + 1.0) * 0.5)
            output_name = "%s_%s_%.2f.png" % (
                structure_name, texture_name, alpha)
            output_path = os.path.join(output_dir, output_name)
            output_image.save(output_path)
            print("Saved at " + output_path)
        return {}