import os
import argparse
import torch
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from utils import transform, adaptive_instance_normalization, linear_histogram_matching, Range
import cv2
import imageio
import numpy as np
from tqdm import tqdm


parser = argparse.ArgumentParser()
parser.add_argument('--content_video', type=str, required=True, help='Content video file path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file path')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()
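
# Example invocation (script, file, and weight names below are hypothetical):
#   python video_style_transfer.py --content_video demo.mp4 \
#       --style_image starry_night.jpg --decoder_weight decoder.pth \
#       --alpha 0.8 --cuda --color_control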

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0):
	"""
	Given content image and style image, generate feature maps with encoder, apply 
	neural style transfer with adaptive instance normalization, generate output image
	with decoder

	Args:
		content_tensor (torch.FloatTensor): Content image 
		style_tensor (torch.FloatTensor): Style Image
		encoder: Encoder (vgg19) network
		decoder: Decoder network
		alpha (float, default=1.0): Weight of style image feature 
	
	Return:
		output_tensor (torch.FloatTensor): Style Transfer output image
	"""

	content_enc = encoder(content_tensor)
	style_enc = encoder(style_tensor)

	# AdaIN: align the content feature statistics to the style features
	transfer_enc = adaptive_instance_normalization(content_enc, style_enc)

	# Blend the stylized features with the original content features
	mix_enc = alpha * transfer_enc + (1 - alpha) * content_enc
	return decoder(mix_enc)
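

# For reference, minimal sketches of the two helpers imported from utils.
# These are illustrative assumptions, not the project's implementations.
#
# adaptive_instance_normalization, in the standard AdaIN formulation of
# Huang & Belongie (2017), aligns the per-channel mean and standard
# deviation of the content features to those of the style features:
def _adain_sketch(content_feat, style_feat, eps=1e-5):
	# Per-channel statistics over the spatial dimensions of (N, C, H, W) maps
	c_mean = content_feat.mean(dim=(2, 3), keepdim=True)
	c_std = content_feat.std(dim=(2, 3), keepdim=True) + eps
	s_mean = style_feat.mean(dim=(2, 3), keepdim=True)
	s_std = style_feat.std(dim=(2, 3), keepdim=True)
	# Normalize the content features, then rescale with the style statistics
	return s_std * (content_feat - c_mean) / c_std + s_mean


# linear_histogram_matching (used for --color_control) is assumed here to
# shift the style tensor's per-channel statistics to match the content
# tensor's, so the stylized output keeps the content's color distribution:
def _linear_histogram_matching_sketch(content_tensor, style_tensor, eps=1e-5):
	# Per-channel statistics over batch and spatial dimensions of (N, C, H, W)
	c_mean = content_tensor.mean(dim=(0, 2, 3), keepdim=True)
	c_std = content_tensor.std(dim=(0, 2, 3), keepdim=True)
	s_mean = style_tensor.mean(dim=(0, 2, 3), keepdim=True)
	s_std = style_tensor.std(dim=(0, 2, 3), keepdim=True) + eps
	return (style_tensor - s_mean) / s_std * c_std + c_mean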


def main():
	# Read video file
	content_video_pth = Path(args.content_video)
	content_video = cv2.VideoCapture(str(content_video_pth))
	style_image_pth = Path(args.style_image)
	style_image = Image.open(style_image_pth).convert('RGB')  # ensure 3-channel RGB

	# Read video info
	fps = int(content_video.get(cv2.CAP_PROP_FPS))
	frame_count = int(content_video.get(cv2.CAP_PROP_FRAME_COUNT))
	video_height = int(content_video.get(cv2.CAP_PROP_FRAME_HEIGHT))
	video_width = int(content_video.get(cv2.CAP_PROP_FRAME_WIDTH))

	# Prepare progress bar over the total frame count
	video_tqdm = tqdm(total=frame_count)

	# Prepare output video writer
	out_dir = './results_video/'
	os.makedirs(out_dir, exist_ok=True)
	out_pth = out_dir + content_video_pth.stem + '_style_' + style_image_pth.stem
	if args.color_control: out_pth += '_colorcontrol'
	out_pth += content_video_pth.suffix
	out_pth = Path(out_pth)
	# mode='I' tells imageio to expect a sequence of frames
	writer = imageio.get_writer(out_pth, mode='I', fps=fps)

	# Load AdaIN model (map weights to the selected device so CPU-only runs work)
	vgg = torch.load('vgg_normalized.pth', map_location=device)
	model = AdaINNet(vgg).to(device)
	model.decoder.load_state_dict(torch.load(args.decoder_weight, map_location=device))
	model.eval()
	model.eval()
	
	# Preprocessing transform from utils, parameterized by target size
	t = transform(512)

	style_tensor = t(style_image).unsqueeze(0).to(device)

	while content_video.isOpened():
		ret, content_image = content_video.read()
		# Failed to read a frame
		if not ret:
			break
		
		# OpenCV reads frames as BGR; convert to RGB before building the PIL image
		content_image = cv2.cvtColor(content_image, cv2.COLOR_BGR2RGB)
		content_tensor = t(Image.fromarray(content_image)).unsqueeze(0).to(device)
		
		# Linear histogram matching if requested. Match against the original
		# style tensor on each frame (rather than overwriting it) so the
		# matching does not compound across frames.
		frame_style_tensor = style_tensor
		if args.color_control:
			frame_style_tensor = linear_histogram_matching(content_tensor, style_tensor)

		with torch.no_grad():
			out_tensor = style_transfer(content_tensor, frame_style_tensor,
				model.encoder, model.decoder, args.alpha).cpu().numpy()
		
		# Convert the output frame back to HWC layout, uint8 range [0, 255],
		# and the original resolution (note: NORM_MINMAX rescales each frame
		# independently)
		out_tensor = np.squeeze(out_tensor, axis=0)
		out_tensor = np.transpose(out_tensor, (1, 2, 0))
		out_tensor = cv2.normalize(src=out_tensor, dst=None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX, dtype=cv2.CV_8U)
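		# Alternative post-processing (assuming the decoder output is roughly
		# in [0, 1]): clamp and scale instead of per-frame min-max
		# normalization, which can cause brightness flicker between frames:
		#   out_tensor = (np.clip(out_tensor, 0.0, 1.0) * 255).astype(np.uint8)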
		out_tensor = cv2.resize(out_tensor, (video_width, video_height), interpolation=cv2.INTER_CUBIC)

		# Write the output frame to the video
		writer.append_data(out_tensor)
		video_tqdm.update(1)

	content_video.release()
	writer.close()
	video_tqdm.close()

	print("\nContent: " + content_video_pth.stem + ". Style: " + style_image_pth.stem +'\n')

if __name__ == '__main__':
	main()