# NOTE: removed non-Python extraction residue (file-size banner, git-blame
# commit hashes, and a line-number listing) that made this file unrunnable.
import os
import argparse
import torch
import time
import numpy as np
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from torchvision.transforms import ToPILImage
from utils import adaptive_instance_normalization, grid_image, transform,linear_histogram_matching, Range
from glob import glob
# ---------------------------------------------------------------------------
# Command-line interface
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help='Content image file path')
parser.add_argument('--content_dir', type=str, help='Content image folder path')
parser.add_argument('--style_image', type=str, help='Style image file path')
# BUG FIX: help text previously read 'Content image folder path' (copy-paste).
parser.add_argument('--style_dir', type=str, help='Style image folder path')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--output_dir', type=str, default="results")
parser.add_argument('--grid_pth', type=str, default=None, help='Specify a grid image path (default=None) if generate a grid image that contains all style transferred images')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()

# Validate required combinations via parser.error() rather than assert:
# asserts are silently stripped under `python -O`, and error() prints the
# usage string plus a readable message, then exits with status 2.
if not (args.content_image or args.content_dir):
    parser.error('one of --content_image or --content_dir is required')
if not (args.style_image or args.style_dir):
    parser.error('one of --style_image or --style_dir is required')
if not args.decoder_weight:
    parser.error('--decoder_weight must be a non-empty path')

# Use CUDA only when both requested and actually available.
device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')
def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """Produce a style-transferred image via AdaIN feature blending.

    Both images are passed through the encoder, the content features are
    re-normalized to the style features' channel statistics (AdaIN), the
    normalized and raw content features are linearly mixed by ``alpha``,
    and the decoder maps the result back to image space.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style Image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of style image feature

    Return:
        output_tensor (torch.FloatTensor): Style Transfer output image
    """
    feat_content = encoder(content_tensor)
    feat_style = encoder(style_tensor)
    feat_adain = adaptive_instance_normalization(feat_content, feat_style)
    # alpha=1.0 -> pure stylized features; alpha=0.0 -> original content.
    blended = feat_adain * alpha + feat_content * (1 - alpha)
    return decoder(blended)
def main():
    """Run AdaIN style transfer for every (content, style) image pair.

    Paths and options come from the module-level ``args``. Stylized images
    are written to ``args.output_dir``; if ``args.grid_pth`` is set, a grid
    image (styles across the top, contents down the side) is also produced.
    """
    # Collect content and style image paths (single file or whole folder).
    if args.content_image:
        content_pths = [Path(args.content_image)]
    else:
        content_pths = [Path(f) for f in glob(args.content_dir + '/*')]

    if args.style_image:
        style_pths = [Path(args.style_image)]
    else:
        style_pths = [Path(f) for f in glob(args.style_dir + '/*')]

    assert len(content_pths) > 0, 'Failed to load content image'
    assert len(style_pths) > 0, 'Failed to load style image'

    # Prepare directory for saving results
    os.makedirs(args.output_dir, exist_ok=True)

    # Load AdaIN model: VGG encoder weights plus the trained decoder weights.
    # NOTE(review): torch.load with weights_only=False unpickles arbitrary
    # objects -- only load checkpoint files from a trusted source.
    vgg = torch.load('vgg_normalized.pth', weights_only=False)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight, weights_only=False))
    model.eval()

    # Prepare image transform
    t = transform(512)

    # Grid layout: first row is [blank cell, style_1, ..., style_S]; each
    # subsequent row is [content_i, output_i1, ..., output_iS].
    if args.grid_pth:
        imgs = [np.ones((1, 1, 3), np.uint8) * 255]  # empty top-left cell
        for style_pth in style_pths:
            imgs.append(Image.open(style_pth))

    times = []
    for content_pth in content_pths:
        content_img = Image.open(content_pth)
        if content_img.mode != "RGB":
            content_img = content_img.convert("RGB")
        content_tensor = t(content_img).unsqueeze(0).to(device)

        if args.grid_pth:
            imgs.append(content_img)

        for style_pth in style_pths:
            out_pth = os.path.join(args.output_dir, content_pth.stem + '_style_' + style_pth.stem + '_alpha' + str(args.alpha) + content_pth.suffix)

            # Skip recomputation if the output already exists, but still add
            # the existing image to the grid. (BUG FIX: previously the bare
            # `continue` left a hole in `imgs`, so grid_image() was called
            # with fewer images than rows*cols implied.)
            if os.path.isfile(out_pth):
                print("Skipping existing file")
                if args.grid_pth:
                    imgs.append(Image.open(out_pth))
                continue

            style_img = Image.open(style_pth)
            if style_img.mode != "RGB":
                style_img = style_img.convert("RGB")
            style_tensor = t(style_img).unsqueeze(0).to(device)

            # Optionally match the style image's color statistics to the
            # content image so the original colors are preserved.
            if args.color_control:
                style_tensor = linear_histogram_matching(content_tensor, style_tensor)

            tic = time.perf_counter()
            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha).cpu()
            toc = time.perf_counter()

            print("Content: " + content_pth.stem + ". Style: " \
                + style_pth.stem + '. Alpha: ' + str(args.alpha) + '. Style Transfer time: %.4f seconds' % (toc - tic))
            times.append(toc - tic)

            save_image(out_tensor, out_pth)
            if args.grid_pth:
                imgs.append(Image.open(out_pth))

    # Remove runtime of first iteration because it is flawed for some unknown reason
    if len(times) > 1:
        times.pop(0)
    # BUG FIX: guard against division by zero when every output already
    # existed and no transfer was actually timed.
    if times:
        avg = sum(times) / len(times)
        print("Average style transfer time: %.4f seconds" % (avg))

    # Generate grid image
    if args.grid_pth:
        print("Generating grid image")
        grid_image(len(content_pths) + 1, len(style_pths) + 1, imgs, save_pth=args.grid_pth)

    print("Finished")
# Script entry point: run the batch style transfer with the parsed CLI args.
if __name__ == '__main__':
    main()