"""Command-line script: style transfer that interpolates between several style images.

Two modes are supported:

* ``--interpolation_weights w1,w2,...``: one output image whose style is the
  weighted mix of the given style images (weights are normalised to sum to 1).
* ``--grid_pth PATH``: exactly four style images; 25 outputs are rendered with
  a fixed 5x5 table of weights and assembled into one grid image.

Results are written to ``./results_interpolate/``.
"""
import argparse
import os
import time
from glob import glob
from pathlib import Path

import numpy as np
import torch
from PIL import Image
from torchvision.utils import save_image

from AdaIN import AdaINNet
from utils import adaptive_instance_normalization, transform, linear_histogram_matching, Range, grid_image

parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help='Test image file path')
parser.add_argument('--style_image', type=str, required=True,
                    help='Multiple Style image file path, separated by comma')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth',
                    help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)],
                    help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--interpolation_weights', type=str,
                    help='Weights of interpolate multiple style images')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--grid_pth', type=str, default=None,
                    help='Specify a grid image path (default=None) if generate a grid image '
                         'that contains all style transferred images. \n'
                         'if use grid mode, provide 4 style images')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()

# Validate with parser.error() instead of assert: asserts are stripped when
# Python runs with -O, and parser.error() prints the usage text to the user.
if not args.content_image:
    parser.error('--content_image is required')
if not args.style_image:
    parser.error('--style_image must not be empty')
if not args.decoder_weight:
    parser.error('--decoder_weight must not be empty')
if not (args.interpolation_weights or args.grid_pth):
    parser.error('provide either --interpolation_weights or --grid_pth')

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def interpolate_style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0, interpolation_weights=None):
    """
    Given content image and multiple style images, encode them, apply adaptive
    instance normalization, mix the per-style results with the interpolation
    weights, blend with the content features by ``alpha`` and decode.

    Args:
        content_tensor (torch.FloatTensor): Content image (``main`` passes a
            single image with a leading batch dimension)
        style_tensor (torch.FloatTensor): Style images stacked along dim 0
        encoder: Encoder network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the stylised features; the
            content features receive ``1 - alpha``
        interpolation_weights (list, optional): Weight of each style image, in
            the same order as ``style_tensor``. ``None`` weights all styles
            equally.

    Return:
        output_tensor (torch.FloatTensor): Interpolate Style Transfer output image
    """
    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    # One stylised feature map per style image, indexed along dim 0.
    # NOTE(review): relies on adaptive_instance_normalization broadcasting the
    # content features against every style - confirm in utils.
    full_enc = adaptive_instance_normalization(content_enc, style_enc)

    if interpolation_weights is None:
        # The None default used to crash in enumerate(); fall back to a uniform mix.
        n_styles = len(full_enc)
        interpolation_weights = [1.0 / n_styles] * n_styles

    # zeros_like already matches content_enc's device and dtype.
    transfer_enc = torch.zeros_like(content_enc)
    for i, w in enumerate(interpolation_weights):
        transfer_enc += w * full_enc[i]

    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)


def main():
    """Run interpolated style transfer for the parsed command-line arguments.

    Reads the content and style images, builds the interpolation weights,
    renders one output per weight vector into ``./results_interpolate/`` and,
    in grid mode, assembles the outputs into a single grid image.
    """
    # Read content and style image.
    # The parser defines no content-directory option, so the content image is
    # always the single --content_image path.
    content_pths = [Path(args.content_image)]
    style_pths = [Path(pth.strip()) for pth in args.style_image.split(',') if pth.strip()]
    if not style_pths:
        parser.error('Failed to load style image')

    inter_weights = []
    if args.grid_pth:
        # Grid mode: 4 style images, 5x5 interpolation weights.
        if len(style_pths) != 4:
            parser.error("Under \ngrid mode, specify 4 style images")
        # NOTE(review): unlike user-supplied weights these rows are not
        # normalised (e.g. a=b=2 gives four weights of 0.5) - confirm intended.
        inter_weights = [
            [min(4 - a, 4 - b) / 4, min(4 - a, b) / 4, min(a, 4 - b) / 4, min(a, b) / 4]
            for a in range(5) for b in range(5)
        ]
    else:
        # Use user input interpolation weights, normalised to sum to 1.
        try:
            raw_weights = [float(w) for w in args.interpolation_weights.split(',')]
        except ValueError:
            parser.error('--interpolation_weights must be comma-separated numbers')
        if len(raw_weights) != len(style_pths):
            # Fewer weights would silently drop styles; more would index past
            # the encoded styles.
            parser.error('number of interpolation weights (%d) must match number of '
                         'style images (%d)' % (len(raw_weights), len(style_pths)))
        total = sum(raw_weights)
        if total == 0:
            parser.error('--interpolation_weights must not sum to zero')
        inter_weights.append([w / total for w in raw_weights])

    out_dir = './results_interpolate/'
    os.makedirs(out_dir, exist_ok=True)

    # Load AdaIN model. map_location lets weights saved on a GPU load on a
    # CPU-only machine.
    vgg = torch.load('vgg_normalized.pth', map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight, map_location=device))
    model.eval()

    # Prepare image transforms once instead of once per image.
    content_transform = transform(512)
    style_transform = transform([512, 512])

    style_names = str([style_pth.stem for style_pth in style_pths])
    imgs = []

    for content_pth in content_pths:
        # The context manager closes the file; convert() forces three channels
        # so grayscale / RGBA inputs do not break the network input.
        with Image.open(content_pth) as img:
            content_tensor = content_transform(img.convert('RGB')).unsqueeze(0).to(device)

        # Prepare multiple style images
        style_tensors = []
        for style_pth in style_pths:
            with Image.open(style_pth) as img:
                style_img = style_transform(img.convert('RGB'))
            if args.color_control:
                # The helper is called with a batch dimension, removed again afterwards.
                style_img = linear_histogram_matching(content_tensor, style_img.unsqueeze(0)).squeeze(0)
            style_tensors.append(style_img)
        style_tensor = torch.stack(style_tensors, dim=0).to(device)

        for inter_weight in inter_weights:
            # Execute Interpolate style transfer
            with torch.no_grad():
                out_tensor = interpolate_style_transfer(
                    content_tensor, style_tensor, model.encoder, model.decoder,
                    args.alpha, inter_weight).cpu()

            print("Content: " + content_pth.stem + ". Style: " + style_names
                  + ". Interpolation weight: ", str(inter_weight))

            # Save results
            out_name = content_pth.stem + '_interpolate_' + str(inter_weight)
            if args.color_control:
                out_name += '_colorcontrol'
            out_name += content_pth.suffix
            out_pth = os.path.join(out_dir, out_name)
            save_image(out_tensor, out_pth)

            if args.grid_pth:
                # Kept open on purpose: grid_image consumes these images below.
                imgs.append(Image.open(out_pth))

    # Generate grid image
    if args.grid_pth:
        print("Generating grid image")
        grid_image(5, 5, imgs, save_pth=args.grid_pth)
    print("Finished")


if __name__ == '__main__':
    main()