# adain/test_api.py
import os
import random
import time
from glob import glob
from pathlib import Path

import numpy as np
import torch
from datasets import load_dataset
from PIL import Image
from torchvision.utils import save_image

from AdaIN import AdaINNet
from utils import adaptive_instance_normalization, transform

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given a content image and a style image, generate feature maps with the
    encoder, apply adaptive instance normalization (AdaIN) to transfer the
    style statistics, and reconstruct the output image with the decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image.
        style_tensor (torch.FloatTensor): Style image.
        encoder: Encoder (VGG-19) network.
        decoder: Decoder network.
        alpha (float, default=1.0): Weight of the style image features.

    Returns:
        torch.FloatTensor: Style transfer output image.
    """
    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)
    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)
    # Blend the stylized features with the original content features.
    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)

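
# For reference, a minimal sketch of what adaptive_instance_normalization
# (imported from utils) computes, following Huang & Belongie (2017):
#   AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y)
# This helper is illustrative only, is not used below, and assumes
# (N, C, H, W) feature maps; the utils implementation may differ in detail.
def _adain_reference(content_feat, style_feat, eps=1e-5):
    c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
    c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
    s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
    s_std = style_feat.std(dim=(2, 3), keepdim=True) + eps
    # Normalize content features, then re-scale with the style statistics.
    return s_std * (content_feat - c_mean) / c_std + s_mean
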

def run_adain(content_dir, style_dataset_pth, out_dir, alpha=1.0, dataset_size=100,
              vgg_pth='vgg_normalized.pth', decoder_pth='decoder.pth'):
    content_pths = [Path(f) for f in glob(content_dir + '/*')]
    num_content_imgs = len(content_pths)
    assert num_content_imgs > 0, 'Failed to load content images from ' + content_dir

    # Load the AdaIN model; map_location lets checkpoints saved on GPU load on CPU.
    vgg = torch.load(vgg_pth, map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(decoder_pth, map_location=device))
    model.eval()

    # Prepare the image transform (512 px)
    t = transform(512)

    # Per-iteration style transfer timings
    times = []

    # Load the style dataset and make sure the output directory exists.
    style_ds = load_dataset(style_dataset_pth, split="train")
    os.makedirs(out_dir, exist_ok=True)

    # Cap the total number of outputs at dataset_size by sampling
    # ceil(dataset_size / num_content_imgs) styles per content image.
    if num_content_imgs * len(style_ds) > dataset_size:
        num_style_per_content = int(np.ceil(dataset_size / num_content_imgs))
    else:
        num_style_per_content = len(style_ds)

    for content_pth in content_pths:
        # Content images must be 3-channel RGB before entering the VGG encoder.
        content_img = Image.open(content_pth).convert('RGB')
        content_tensor = t(content_img).unsqueeze(0).to(device)

        # Randomly pick style images for this content image.
        indices = random.sample(range(len(style_ds)), num_style_per_content)
        for idx in indices:
            style_img = style_ds[idx]['image']
            if style_img.mode != 'RGB':
                style_img = style_img.convert('RGB')
            style_tensor = t(style_img).unsqueeze(0).to(device)

            # Execute style transfer and time it
            tic = time.perf_counter()
            with torch.no_grad():
                out_tensor = style_transfer(content_tensor, style_tensor,
                                            model.encoder, model.decoder, alpha).cpu()
            toc = time.perf_counter()

            print('Content: %s. Style: %d. Alpha: %s. Style transfer time: %.4f seconds'
                  % (content_pth.stem, idx, alpha, toc - tic))
            times.append(toc - tic)

            # Save image
            out_pth = Path(out_dir) / (content_pth.stem + '_style_' + str(idx)
                                       + '_alpha' + str(alpha) + content_pth.suffix)
            save_image(out_tensor, out_pth)
    # Drop the first iteration's runtime: it is likely inflated by one-time
    # warm-up costs (e.g. CUDA initialization) and would skew the average.
    if len(times) > 1:
        times.pop(0)
    avg = sum(times) / len(times)
    print('Average style transfer time: %.4f seconds' % avg)