import os
import argparse
from pathlib import Path

import cv2
import imageio
import numpy as np
import torch
from PIL import Image
from tqdm import tqdm

from AdaIN import AdaINNet
from utils import transform, adaptive_instance_normalization, linear_histogram_matching, Range


parser = argparse.ArgumentParser()
parser.add_argument('--content_video', type=str, required=True, help='Content video file path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file path')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls the strength of style transfer')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()
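# Example invocation (script and file names below are illustrative; substitute
# the actual script name and your own content video, style image, and trained
# decoder weights):
#
#   python video_style_transfer.py --content_video content.mp4 \
#       --style_image style.jpg --decoder_weight decoder.pth --alpha 0.8 --cuda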

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
    """
    Given a content image and a style image, generate feature maps with the encoder,
    apply neural style transfer with adaptive instance normalization, and generate
    the output image with the decoder.

    Args:
        content_tensor (torch.FloatTensor): Content image
        style_tensor (torch.FloatTensor): Style image
        encoder: Encoder (VGG-19) network
        decoder: Decoder network
        alpha (float, default=1.0): Weight of the style image features

    Return:
        output_tensor (torch.FloatTensor): Style transfer output image
    """
    content_enc = encoder(content_tensor)
    style_enc = encoder(style_tensor)

    transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

    # Interpolate between the content features and the stylized features
    mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
    return decoder(mix_enc)
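
# For reference, adaptive_instance_normalization (imported from utils above) is
# expected to compute AdaIN(x, y) = sigma(y) * (x - mu(x)) / sigma(x) + mu(y),
# where mu/sigma are per-channel mean and standard deviation. The sketch below
# is illustrative only: it assumes 4-D (N, C, H, W) feature maps, and the actual
# utils implementation may differ in details such as the epsilon value.
def _adain_reference(content_feat, style_feat, eps=1e-5):
    n, c = content_feat.size()[:2]
    # Per-channel statistics over the spatial dimensions
    c_mean = content_feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    c_std = (content_feat.view(n, c, -1).var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    s_mean = style_feat.view(n, c, -1).mean(dim=2).view(n, c, 1, 1)
    s_std = (style_feat.view(n, c, -1).var(dim=2) + eps).sqrt().view(n, c, 1, 1)
    # Normalize the content features, then re-scale/shift with the style statistics
    return s_std * (content_feat - c_mean) / c_std + s_mean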


def main():
    content_video_pth = Path(args.content_video)
    content_video = cv2.VideoCapture(str(content_video_pth))
    style_image_pth = Path(args.style_image)
    style_image = Image.open(style_image_pth)

    # Video properties of the content clip
    fps = int(content_video.get(cv2.CAP_PROP_FPS))
    frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
    video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
    video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))

    # Progress bar over the total number of frames
    video_tqdm = tqdm(total=frame_count)

    # Build the output path: <content>_style_<style>[_colorcontrol].<ext>
    out_dir = './results_video/'
    os.makedirs(out_dir, exist_ok=True)
    out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
    if args.color_control:
        out_pth += '_colorcontrol'
    out_pth += content_video_pth.suffix
    out_pth = Path(out_pth)
    writer = imageio.get_writer(out_pth, mode='I', fps=fps)

    # Load the pretrained VGG encoder weights and the trained decoder weights
    vgg = torch.load('vgg_normalized.pth', map_location=device)
    model = AdaINNet(vgg).to(device)
    model.decoder.load_state_dict(torch.load(args.decoder_weight, map_location=device))
    model.eval()

    t = transform(512)

    style_tensor = t(style_image).unsqueeze(0).to(device)

    while content_video.isOpened():
        ret, content_image = content_video.read()
        if not ret:
            break

        # OpenCV decodes frames as BGR; convert to RGB before building the tensor
        content_image = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
        content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)

        # Optionally match the style image's color statistics to the content frame
        if args.color_control:
            style_tensor = linear_histogram_matching(content_tensor, style_tensor)

        with torch.no_grad():
            out_tensor = style_transfer(content_tensor, style_tensor, model.encoder,
                                        model.decoder, args.alpha).cpu().numpy()

        # Convert the (1, C, H, W) float output into an (H, W, C) uint8 frame
        # at the original video resolution
        out_tensor = np.squeeze(out_tensor, axis=0)
        out_tensor = np.transpose(out_tensor, (1, 2, 0))
        out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
        out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)

        writer.append_data(np.array(out_tensor))
        video_tqdm.update(1)

    content_video.release()
    writer.close()
    video_tqdm.close()

    print(f"\nContent: {content_video_pth.stem}. Style: {style_image_pth.stem}\n")


if __name__ == '__main__':
    main()