SofaStyler / test.py
Sophie98
test commit
ab92204
raw history blame
No virus
5.76 kB
import argparse
from pathlib import Path
import os
import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image
from function import calc_mean_std, normal, coral
import transformer as transformer
import StyTR as StyTR
import matplotlib.pyplot as plt
from matplotlib import cm
from function import normal
import numpy as np
def test_transform(size, crop):
transform_list = []
if size != 0:
transform_list.append(transforms.Resize(size))
if crop:
transform_list.append(transforms.CenterCrop(size))
transform_list.append(transforms.ToTensor())
transform = transforms.Compose(transform_list)
return transform
def style_transform(h,w):
k = (h,w)
size = int(np.max(k))
transform_list = []
transform_list.append(transforms.CenterCrop((h,w)))
transform_list.append(transforms.ToTensor())
transform = transforms.Compose(transform_list)
return transform
def content_transform():
transform_list = []
transform_list.append(transforms.ToTensor())
transform = transforms.Compose(transform_list)
return transform
parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
help='File path to the content image')
parser.add_argument('--content_dir', type=str,
help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
help='File path to the style image, or multiple style \
images separated by commas if you want to do style \
interpolation or spatial control')
parser.add_argument('--style_dir', type=str,
help='Directory path to a batch of style images')
parser.add_argument('--output', type=str, default='output',
help='Directory to save the output image(s)')
parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')
parser.add_argument('--style_interpolation_weights', type=str, default="")
parser.add_argument('--a', type=float, default=1.0)
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
help="Type of positional embedding to use on top of the image features")
parser.add_argument('--hidden_dim', default=512, type=int,
help="Size of the embeddings (dimension of the transformer)")
args = parser.parse_args()
# Advanced options
content_size=640
style_size=640
crop='store_true'
save_ext='.jpg'
output_path=args.output
preserve_color='store_true'
alpha=args.a
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Either --content or --content_dir should be given.
if args.content:
content_paths = [Path(args.content)]
else:
content_dir = Path(args.content_dir)
content_paths = [f for f in content_dir.glob('*')]
# Either --style or --style_dir should be given.
if args.style:
style_paths = [Path(args.style)]
else:
style_dir = Path(args.style_dir)
style_paths = [f for f in style_dir.glob('*')]
if not os.path.exists(output_path):
os.mkdir(output_path)
vgg = StyTR.vgg
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:44])
decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()
decoder.eval()
Trans.eval()
vgg.eval()
from collections import OrderedDict
new_state_dict = OrderedDict()
state_dict = torch.load(args.decoder_path)
for k, v in state_dict.items():
#namekey = k[7:] # remove `module.`
namekey = k
new_state_dict[namekey] = v
decoder.load_state_dict(new_state_dict)
new_state_dict = OrderedDict()
state_dict = torch.load(args.Trans_path)
for k, v in state_dict.items():
#namekey = k[7:] # remove `module.`
namekey = k
new_state_dict[namekey] = v
Trans.load_state_dict(new_state_dict)
new_state_dict = OrderedDict()
state_dict = torch.load(args.embedding_path)
for k, v in state_dict.items():
#namekey = k[7:] # remove `module.`
namekey = k
new_state_dict[namekey] = v
embedding.load_state_dict(new_state_dict)
network = StyTR.StyTrans(vgg,decoder,embedding,Trans,args)
network.eval()
network.to(device)
content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)
for content_path in content_paths:
for style_path in style_paths:
print(content_path)
content_tf1 = content_transform()
content = content_tf(Image.open(content_path).convert("RGB"))
h,w,c=np.shape(content)
style_tf1 = style_transform(h,w)
style = style_tf(Image.open(style_path).convert("RGB"))
style = style.to(device).unsqueeze(0)
content = content.to(device).unsqueeze(0)
with torch.no_grad():
output= network(content,style)
output = output[0].cpu()
output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
output_path, splitext(basename(content_path))[0],
splitext(basename(style_path))[0], save_ext
)
save_image(output, output_name)