|
import torch |
|
import torchvision.transforms |
|
from PIL import Image |
|
|
|
from AdaIN import AdaINNet |
|
from utils import adaptive_instance_normalization, transform, linear_histogram_matching |
|
|
|
# Prefer the GPU when CUDA is usable; otherwise run everything on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
|
|
|
|
|
def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Stylize a content image with adaptive instance normalization (AdaIN).

    Both images are passed through ``encoder``; the content features are
    handed, together with the style features, to
    ``adaptive_instance_normalization``; the result is blended with the
    plain content features according to ``alpha`` and fed to ``decoder``.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (vgg19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the style image feature

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """
    content_features = encoder(content_tensor)
    style_features = encoder(style_tensor)

    stylized_features = adaptive_instance_normalization(content_features, style_features)

    # alpha == 1 keeps only the stylized features, alpha == 0 only the
    # original content features; values in between interpolate linearly.
    blended_features = alpha * stylized_features + (1 - alpha) * content_features

    output_tensor = decoder(blended_features)
    return output_tensor
|
|
|
|
|
def convert(content_path, style_path, vgg_weights_path, decoder_weights_path, alpha, color_control,
            output_path='output.png'):
    """
    Run AdaIN style transfer on one content/style image pair and save the result.

    Args:
        content_path: Path of the content image file
        style_path: Path of the style image file
        vgg_weights_path: Path of the saved VGG weights handed to ``AdaINNet``
        decoder_weights_path: Path of the saved decoder state dict
        alpha (float): Weight of the style image feature, forwarded to
            ``style_transfer``
        color_control: When truthy, the style tensor is first passed through
            ``linear_histogram_matching`` together with the content tensor
        output_path (str, default='output.png'): File the output image is
            written to

    Return:
        The path of the written image (``output_path``)
    """
    # map_location lets checkpoints that were saved on a GPU load on a
    # CPU-only machine instead of raising a CUDA deserialization error.
    # NOTE(review): torch.load unpickles its input -- only load weight files
    # from a trusted source.
    vgg = torch.load(vgg_weights_path, map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(decoder_weights_path, map_location=device))
    model.eval()

    t = transform(512)

    # Open the images in `with` blocks so the file handles are released, and
    # force three channels so grayscale / RGBA / palette files do not produce
    # 1- or 4-channel tensors (assumes the encoder expects RGB input -- the
    # docstring of style_transfer describes it as vgg19).
    with Image.open(content_path) as content_img:
        content_tensor = t(content_img.convert('RGB')).unsqueeze(0).to(device)
    with Image.open(style_path) as style_img:
        style_tensor = t(style_img.convert('RGB')).unsqueeze(0).to(device)

    if color_control:
        style_tensor = linear_histogram_matching(content_tensor, style_tensor)

    # Inference only: skip autograd bookkeeping.
    with torch.no_grad():
        out_tensor = style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, alpha).cpu()

    torchvision.utils.save_image(out_tensor.squeeze(0), output_path)

    return output_path
|
|