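"""
Interpolated style transfer with AdaIN.

Blends several style images into a single content image using per-style
interpolation weights, or renders a 5x5 grid image that interpolates between
4 style images (grid mode).

Example invocation (script and file names are illustrative):
	python test_interpolate.py --content_image content.jpg \
		--style_image style1.jpg,style2.jpg \
		--interpolation_weights 0.6,0.4 --cuda
"""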
import os
import argparse
import torch
import time
import numpy as np
from pathlib import Path
from AdaIN import AdaINNet
from PIL import Image
from torchvision.utils import save_image
from utils import adaptive_instance_normalization, transform, linear_histogram_matching, Range, grid_image
from glob import glob

parser = argparse.ArgumentParser()
parser.add_argument('--content_image', type=str, help='Content image file path')
parser.add_argument('--content_dir', type=str, help='Content image folder path')
parser.add_argument('--style_image', type=str, required=True, help='Style image file paths, separated by commas')
parser.add_argument('--decoder_weight', type=str, default='decoder.pth', help='Decoder weight file path')
parser.add_argument('--alpha', type=float, default=1.0, choices=[Range(0.0, 1.0)], help='Alpha [0.0, 1.0] controls style transfer level')
parser.add_argument('--interpolation_weights', type=str, help='Comma-separated weights for interpolating the style images')
parser.add_argument('--cuda', action='store_true', help='Use CUDA')
parser.add_argument('--grid_pth', type=str, default=None, help='Path for a grid image that contains all style-transferred results (default=None). Grid mode requires exactly 4 style images.')
parser.add_argument('--color_control', action='store_true', help='Preserve content color')
args = parser.parse_args()
assert args.content_image or args.content_dir, 'Provide either --content_image or --content_dir'
assert args.style_image
assert args.decoder_weight
assert args.interpolation_weights or args.grid_pth, 'Provide --interpolation_weights or --grid_pth'

device = torch.device('cuda' if args.cuda and torch.cuda.is_available() else 'cpu')


def interpolate_style_transfer(content_tensor, style_tensor, encoder, decoder, alpha=1.0, interpolation_weights=None):
	"""
	Given a content image and multiple style images, encode them with the encoder,
	apply adaptive instance normalization to transfer each style, combine the
	stylized feature maps using the interpolation weights, and decode the result
	into the output image.

	Args:
		content_tensor (torch.FloatTensor): Content image
		style_tensor (torch.FloatTensor): Batch of style images
		encoder: Encoder (vgg19) network
		decoder: Decoder network
		alpha (float, default=1.0): Weight of style image features
		interpolation_weights (list): Weight of each style image

	Return:
		output_tensor (torch.FloatTensor): Interpolated style transfer output image
	"""
	
	content_enc = encoder(content_tensor)
	style_enc = encoder(style_tensor)
	
	transfer_enc = torch.zeros_like(content_enc).to(device)
	full_enc = adaptive_instance_normalization(content_enc, style_enc)
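	# Interpolation: accumulate a weighted sum of the per-style AdaIN feature maps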
	for i, w in enumerate(interpolation_weights):
		transfer_enc += w * full_enc[i]

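	# Content-style trade-off: alpha=1.0 uses only the stylized features, alpha=0.0 returns the content features unchanged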
	mix_enc = alpha * transfer_enc + (1-alpha) * content_enc
	return decoder(mix_enc)
	

def main():	
	# Read content and style image
	if args.content_image:
		content_pths = [Path(args.content_image)]
	else:
		content_pths = [Path(f) for f in glob(args.content_dir + '/*')]

	style_pths_list = args.style_image.split(',')
	style_pths = [Path(pth) for pth in style_pths_list]

	assert len(content_pths) > 0, 'Failed to load content image'
	assert len(style_pths) > 0, 'Failed to load style image'
	
	inter_weights = []
	# If grid mode, use 4 style images, 5x5 interpolation weights
	if args.grid_pth:
		assert len(style_pths) == 4, "Under grid mode, specify 4 style images"
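		# Each cell (a, b) of the 5x5 grid weights the 4 styles so that the
		# grid corners are the pure styles and interior cells mix them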
		inter_weights = [ [ min(4-a, 4-b) / 4,  min(4-a, b) / 4, min(a, 4-b) / 4, min(a, b) / 4] \
			for a in range(5) for b in range(5) ]

	# Use user input interpolation weights
	else:
		inter_weight = [float(i) for i in args.interpolation_weights.split(',')]
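		# Normalize so the user-supplied weights sum to 1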
		inter_weight = [i / sum(inter_weight) for i in inter_weight]
		inter_weights.append(inter_weight)
	

	out_dir = './results_interpolate/'
	os.makedirs(out_dir, exist_ok=True)
	
	# Load AdaIN model
	vgg = torch.load('vgg_normalized.pth')
	model = AdaINNet(vgg).to(device)
	model.decoder.load_state_dict(torch.load(args.decoder_weight))
	model.eval()
	
	# Prepare image transform
	t = transform(512)

	imgs = []

	for content_pth in content_pths:
		content_tensor = t(Image.open(content_pth)).unsqueeze(0).to(device)

		# Prepare multiple style images
		style_tensor = []
		for style_pth in style_pths:
			img = Image.open(style_pth)
			if args.color_control:
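				# Match the style image's color statistics to the content image so the result preserves the content colors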
				img = transform([512, 512])(img).unsqueeze(0)
				img = linear_histogram_matching(content_tensor, img)
				img = img.squeeze(0)
				style_tensor.append(img)
			else:
				style_tensor.append(transform([512, 512])(img))
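		# Stack all style images into one batch so the encoder processes them together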
		style_tensor = torch.stack(style_tensor, dim=0).to(device)
		
		for inter_weight in inter_weights:
			# Execute interpolated style transfer
			with torch.no_grad():
				out_tensor = interpolate_style_transfer(content_tensor, style_tensor, model.encoder, model.decoder, args.alpha, inter_weight).cpu()
			
			print("Content: " + content_pth.stem + ". Style: " + str([style_pth.stem for style_pth in style_pths]) + ". Interpolation weight: " + str(inter_weight))

			# Save results
			out_pth = out_dir + content_pth.stem + '_interpolate_' + str(inter_weight)
			if args.color_control: out_pth += '_colorcontrol'
			out_pth += content_pth.suffix
			save_image(out_tensor, out_pth)

			if args.grid_pth:
				imgs.append(Image.open(out_pth))

	# Generate grid image
	if args.grid_pth:
		print("Generating grid image")
		grid_image(5, 5, imgs, save_pth=args.grid_pth)
		print("Finished")

if __name__ == '__main__':
	main()