Sophie98 committed
Commit e4fb230 · Parent: 79c6687

restructured code

Files changed:
- .gitignore        +8   -0
- StyTR.py          +1   -1
- app.py            +72  -2
- segmentation.py   +6   -8
- styleTransfer.py  +108 -71
- test.py           +0   -175
.gitignore
ADDED
@@ -0,0 +1,8 @@
+__pycache__/box_ops.cpython-37.pyc
+__pycache__/function.cpython-37.pyc
+__pycache__/misc.cpython-37.pyc
+__pycache__/segmentation.cpython-37.pyc
+__pycache__/styleTransfer.cpython-37.pyc
+__pycache__/StyTR.cpython-37.pyc
+__pycache__/transformer.cpython-37.pyc
+__pycache__/ViT_helper.cpython-37.pyc
StyTR.py
CHANGED
@@ -137,7 +137,7 @@ class MLP(nn.Module):
 class StyTrans(nn.Module):
     """ This is the style transform transformer module """

-    def __init__(self,encoder,decoder,PatchEmbed, transformer
+    def __init__(self,encoder,decoder,PatchEmbed, transformer):

         super().__init__()
         enc_layers = list(encoder.children())
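The single-line change above completes the `__init__` signature so that `StyTrans` is built from exactly the four sub-modules passed in by the restructured styleTransfer.py later in this commit. A hedged sketch of how that constructor is now invoked, mirroring the call in styleTransfer.py (checkpoint loading omitted):

```python
# Sketch only: mirrors how styleTransfer.py constructs the network in this commit.
import torch.nn as nn
import StyTR
import transformer

vgg = nn.Sequential(*list(StyTR.vgg.children())[:44])   # truncated VGG encoder
decoder = StyTR.decoder                                  # CNN decoder
embedding = StyTR.PatchEmbed()                           # patch embedding for content/style
Trans = transformer.Transformer()                        # transformer backbone

network = StyTR.StyTrans(vgg, decoder, embedding, Trans) # four positional modules, per the fixed signature
network.eval()
```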
app.py
CHANGED
@@ -1,9 +1,78 @@
 import numpy as np
 import gradio as gr
 from segmentation import get_mask,replace_sofa
-from styleTransfer import
+from styleTransfer import create_styledSofa
 from PIL import Image

+def resize_sofa(img):
+    """
+    This function adds padding to make the orignal image square and 640by640.
+    It also returns the orignal ratio of the image, such that it can be reverted later.
+    Parameters:
+        img = original image
+    Return:
+        im1 = squared image
+        box = parameters to later crop the image to it original ratio
+    """
+    width, height = img.size
+    idx = np.argmin([width,height])
+    newsize = (640, 640) # parameters from test script
+
+    if idx==0:
+        img1 = Image.new(img.mode, (height, height), (255, 255, 255))
+        img1.paste(img, ((height-width)//2, 0))
+        box = ( newsize[0]*(1-width/height)//2,
+                0,
+                newsize[0]-newsize[0]*(1-width/height)//2,
+                newsize[1])
+    else:
+        img1 = Image.new(img.mode, (width, width), (255, 255, 255))
+        img1.paste(img, (0, (width-height)//2))
+        box = (0,
+               newsize[1]*(1-height/width)//2,
+               newsize[0],
+               newsize[1]-newsize[1]*(1-height/width)//2)
+    im1 = img1.resize(newsize)
+    return im1,box
+
+def resize_style(img):
+    """
+    This function generates a zoomed out version of the style image and resizes it to a 640by640 square.
+    Parameters:
+        img = image containing the style/pattern
+    Return:
+        dst = a zoomed-out and resized version of the pattern
+    """
+    width, height = img.size
+    idx = np.argmin([width,height])
+
+    # Makes the image square by cropping
+    if idx==0:
+        top= (height-width)//2
+        bottom= height-(height-width)//2
+        left = 0
+        right= width
+    else:
+        left = (width-height)//2
+        right = width - (width-height)//2
+        top = 0
+        bottom = height
+    newsize = (640, 640) # parameters from test script
+    im1 = img.crop((left, top, right, bottom))
+
+    # Constructs a zoomed-out version
+    copies = 8
+    resize = (newsize[0]//copies,newsize[1]//copies)
+    dst = Image.new('RGB', (resize[0]*copies,resize[1]*copies))
+    im2 = im1.resize((resize))
+    for row in range(copies):
+        im2 = im2.transpose(Image.FLIP_LEFT_RIGHT)
+        for column in range(copies):
+            im2 = im2.transpose(Image.FLIP_TOP_BOTTOM)
+            dst.paste(im2, (resize[0]*row, resize[1]*column))
+    dst = dst.resize((newsize))
+    return dst
+
 def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
     """
     Styles (all) the sofas in the image to the given style.
@@ -17,6 +86,7 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
     """

     # preprocess input images to be (640,640) squares to fit requirements of the segmentation model
+    input_img,style_img = Image.fromarray(input_img),Image.fromarray(style_img)
     resized_img,box = resize_sofa(input_img)
     resized_style = resize_style(style_img)
     # generate mask for image
@@ -24,7 +94,7 @@
     styled_sofa = create_styledSofa(resized_img,resized_style)
     # postprocess the final image
     new_sofa = replace_sofa(resized_img,mask,styled_sofa)
-    new_sofa =
+    new_sofa = new_sofa.crop(box)
     return new_sofa

 image = gr.inputs.Image()
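The two helpers added above handle the aspect-ratio round trip: resize_sofa pads the shorter side with white to obtain a 640x640 square and returns the crop box that style_sofa later applies with new_sofa.crop(box), while resize_style tiles a mirrored 8x8 grid of the pattern to zoom it out. A minimal sketch of the padding/crop round trip, assuming the helper is imported from this app.py (note that importing app.py also builds the Gradio interface, so in practice you may prefer to copy the function; the file path is a placeholder):

```python
from PIL import Image

from app import resize_sofa  # helper added in this commit; importing app also sets up the Gradio UI

img = Image.open("my_sofa.jpg")      # placeholder path, any aspect ratio
square, box = resize_sofa(img)       # white-padded and resized to a 640x640 square
assert square.size == (640, 640)

# ... segmentation and style transfer operate on the 640x640 square ...

restored = square.crop(box)          # drops the padding again, restoring the original aspect ratio
```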
segmentation.py
CHANGED
@@ -7,7 +7,7 @@ import matplotlib.pyplot as plt
 from PIL import Image
 import segmentation_models as sm

-def get_mask(image):
+def get_mask(image:Image) -> Image:
     """
     This function generates a mask of the image that highlights all the sofas in the image.
     This uses a pre-trained Unet model with a resnet50 backbone.
@@ -50,13 +50,11 @@
     test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
     test_img = np.expand_dims(test_img, axis=0)

-    prediction = model.predict(preprocess_input(test_img)).round()
-    print("generated mask")
+    prediction = model.predict(preprocess_input(np.array(test_img))).round()
     mask = Image.fromarray(prediction[...,0].squeeze()*255).convert("L")
-
-    return np.array(mask)
+    return mask

-def replace_sofa(image,mask,styled_sofa):
+def replace_sofa(image:Image, mask:Image, styled_sofa:Image) -> Image:
     """
     This function replaces the original sofa in the image by the new styled sofa according
     to the mask.
@@ -68,7 +66,7 @@ def replace_sofa(image,mask,styled_sofa):
     Return:
         new_image = Image containing the styled sofa
     """
-    image = np.array(image)
+    image,mask,styled_sofa = np.array(image),np.array(mask),np.array(styled_sofa)
     #image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
     styled_sofa = cv2.cvtColor(styled_sofa, cv2.COLOR_BGR2RGB)

@@ -77,7 +75,7 @@ def replace_sofa(image,mask,styled_sofa):
     image_bg = cv2.bitwise_and(image,image,mask = mask_inv)
     sofa_fg = cv2.bitwise_and(styled_sofa,styled_sofa,mask = mask)
     new_image = cv2.add(image_bg,sofa_fg)
-    return new_image
+    return Image.fromarray(new_image)

 # image = cv2.imread('input/sofa.jpg')
 # mask = cv2.imread('masks/sofa.jpg')
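After this change both helpers speak PIL: get_mask takes an image and returns an "L"-mode mask, and replace_sofa converts to numpy internally and hands back a PIL image, so app.py no longer juggles array/image conversions. A hedged sketch of the new interface (the file paths are placeholders, echoing the commented-out example above):

```python
from PIL import Image
from segmentation import get_mask, replace_sofa

sofa = Image.open("input/sofa.jpg")        # placeholder path
styled = Image.open("output/styled.jpg")   # placeholder: result of the style-transfer step, same 640x640 size

mask = get_mask(sofa)                      # PIL "L" mask: sofa pixels white, background black
result = replace_sofa(sofa, mask, styled)  # PIL image with the styled sofa composited in
result.save("output/new_sofa.jpg")
```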
styleTransfer.py
CHANGED
@@ -1,77 +1,114 @@
 from PIL import Image
 import numpy as np
-import
-import
-... (old lines 5-73 removed; their content is not shown in this view)
+import torch
+import torch.nn as nn
+from PIL import Image
+from torchvision import transforms
+import transformer as transformer
+import StyTR as StyTR
+import numpy as np
+from collections import OrderedDict
+
+def test_transform(size, crop):
+    transform_list = []
+
+    if size != 0:
+        transform_list.append(transforms.Resize(size))
+    if crop:
+        transform_list.append(transforms.CenterCrop(size))
+    transform_list.append(transforms.ToTensor())
+    transform = transforms.Compose(transform_list)
+    return transform
+def style_transform(h,w):
+    k = (h,w)
+    size = int(np.max(k))
+    transform_list = []
+    transform_list.append(transforms.CenterCrop((h,w)))
+    transform_list.append(transforms.ToTensor())
+    transform = transforms.Compose(transform_list)
+    return transform
+
+def content_transform():
+    transform_list = []
+    transform_list.append(transforms.ToTensor())
+    transform = transforms.Compose(transform_list)
+    return transform
+
+def Transformer(content_img: Image, style_img: Image,
+                vgg_path:str = 'vgg_normalised.pth', decoder_path:str = 'decoder_iter_160000.pth',
+                Trans_path:str = 'transformer_iter_160000.pth', embedding_path:str = 'embedding_iter_160000.pth'):
+    # Advanced options
+    content_size=640
+    style_size=640
+    crop='store_true'
+
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    vgg = StyTR.vgg
+    vgg.load_state_dict(torch.load(vgg_path))
+    vgg = nn.Sequential(*list(vgg.children())[:44])
+
+    decoder = StyTR.decoder
+    Trans = transformer.Transformer()
+    embedding = StyTR.PatchEmbed()
+
+    decoder.eval()
+    Trans.eval()
+    vgg.eval()
+
+    new_state_dict = OrderedDict()
+    state_dict = torch.load(decoder_path)
+    for k, v in state_dict.items():
+        #namekey = k[7:] # remove `module.`
+        namekey = k
+        new_state_dict[namekey] = v
+    decoder.load_state_dict(new_state_dict)
+
+    new_state_dict = OrderedDict()
+    state_dict = torch.load(Trans_path)
+    for k, v in state_dict.items():
+        #namekey = k[7:] # remove `module.`
+        namekey = k
+        new_state_dict[namekey] = v
+    Trans.load_state_dict(new_state_dict)
+
+    new_state_dict = OrderedDict()
+    state_dict = torch.load(embedding_path)
+    for k, v in state_dict.items():
+        #namekey = k[7:] # remove `module.`
+        namekey = k
+        new_state_dict[namekey] = v
+    embedding.load_state_dict(new_state_dict)
+
+    network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
+    network.eval()
+    network.to(device)
+
+
+
+    content_tf = test_transform(content_size, crop)
+    style_tf = test_transform(style_size, crop)
+
+
+    content_tf1 = content_transform()
+    content = content_tf(content_img.convert("RGB"))
+
+    h,w,c=np.shape(content)
+    style_tf1 = style_transform(h,w)
+    style = style_tf(style_img.convert("RGB"))
+
+
+    style = style.to(device).unsqueeze(0)
+    content = content.to(device).unsqueeze(0)

+    with torch.no_grad():
+        output= network(content,style)
+    output = output[0].cpu()
+    output = transforms.ToPILImage(output)
+    return output
+
+def create_styledSofa(sofa:Image, style:Image):
+    styled_sofa = transformer(sofa,style)
     return styled_sofa

 # image = Image.open('sofa_office.jpg')
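styleTransfer.py now absorbs the pipeline that used to live in test.py: model construction and checkpoint loading happen inside Transformer(), with the four .pth paths as keyword defaults, and create_styledSofa is the thin entry point imported by app.py. A hedged sketch of how app.py is expected to drive it (file paths are placeholders; the checkpoints named in Transformer()'s defaults are assumed to sit in the working directory):

```python
from PIL import Image
from styleTransfer import create_styledSofa

content = Image.open("sofa_office.jpg")   # same example image as the commented-out line above
pattern = Image.open("style.jpg")         # placeholder: app.py passes the tiled 640x640 pattern here

# Entry point used by app.py: both arguments are PIL images of the same 640x640 size.
styled = create_styledSofa(content, pattern)
```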
test.py
DELETED
@@ -1,175 +0,0 @@
-import argparse
-from pathlib import Path
-import os
-import torch
-import torch.nn as nn
-from PIL import Image
-from os.path import basename
-from os.path import splitext
-from torchvision import transforms
-from torchvision.utils import save_image
-from function import calc_mean_std, normal, coral
-import transformer as transformer
-import StyTR as StyTR
-import matplotlib.pyplot as plt
-from matplotlib import cm
-from function import normal
-import numpy as np
-
-def test_transform(size, crop):
-    transform_list = []
-
-    if size != 0:
-        transform_list.append(transforms.Resize(size))
-    if crop:
-        transform_list.append(transforms.CenterCrop(size))
-    transform_list.append(transforms.ToTensor())
-    transform = transforms.Compose(transform_list)
-    return transform
-def style_transform(h,w):
-    k = (h,w)
-    size = int(np.max(k))
-    transform_list = []
-    transform_list.append(transforms.CenterCrop((h,w)))
-    transform_list.append(transforms.ToTensor())
-    transform = transforms.Compose(transform_list)
-    return transform
-
-def content_transform():
-
-    transform_list = []
-    transform_list.append(transforms.ToTensor())
-    transform = transforms.Compose(transform_list)
-    return transform
-
-
-
-parser = argparse.ArgumentParser()
-# Basic options
-parser.add_argument('--content', type=str,
-                    help='File path to the content image')
-parser.add_argument('--content_dir', type=str,
-                    help='Directory path to a batch of content images')
-parser.add_argument('--style', type=str,
-                    help='File path to the style image, or multiple style \
-                        images separated by commas if you want to do style \
-                        interpolation or spatial control')
-parser.add_argument('--style_dir', type=str,
-                    help='Directory path to a batch of style images')
-parser.add_argument('--output', type=str, default='output',
-                    help='Directory to save the output image(s)')
-parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
-parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
-parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
-parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')
-
-
-parser.add_argument('--style_interpolation_weights', type=str, default="")
-parser.add_argument('--a', type=float, default=1.0)
-parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
-                    help="Type of positional embedding to use on top of the image features")
-parser.add_argument('--hidden_dim', default=512, type=int,
-                    help="Size of the embeddings (dimension of the transformer)")
-args = parser.parse_args()
-
-# Advanced options
-content_size=640
-style_size=640
-crop='store_true'
-save_ext='.jpg'
-output_path=args.output
-preserve_color='store_true'
-alpha=args.a
-
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-# Either --content or --content_dir should be given.
-if args.content:
-    content_paths = [Path(args.content)]
-else:
-    content_dir = Path(args.content_dir)
-    content_paths = [f for f in content_dir.glob('*')]
-
-# Either --style or --style_dir should be given.
-if args.style:
-    style_paths = [Path(args.style)]
-else:
-    style_dir = Path(args.style_dir)
-    style_paths = [f for f in style_dir.glob('*')]
-
-if not os.path.exists(output_path):
-    os.mkdir(output_path)
-
-
-vgg = StyTR.vgg
-vgg.load_state_dict(torch.load(args.vgg))
-vgg = nn.Sequential(*list(vgg.children())[:44])
-
-decoder = StyTR.decoder
-Trans = transformer.Transformer()
-embedding = StyTR.PatchEmbed()
-
-decoder.eval()
-Trans.eval()
-vgg.eval()
-from collections import OrderedDict
-new_state_dict = OrderedDict()
-state_dict = torch.load(args.decoder_path)
-for k, v in state_dict.items():
-    #namekey = k[7:] # remove `module.`
-    namekey = k
-    new_state_dict[namekey] = v
-decoder.load_state_dict(new_state_dict)
-
-new_state_dict = OrderedDict()
-state_dict = torch.load(args.Trans_path)
-for k, v in state_dict.items():
-    #namekey = k[7:] # remove `module.`
-    namekey = k
-    new_state_dict[namekey] = v
-Trans.load_state_dict(new_state_dict)
-
-new_state_dict = OrderedDict()
-state_dict = torch.load(args.embedding_path)
-for k, v in state_dict.items():
-    #namekey = k[7:] # remove `module.`
-    namekey = k
-    new_state_dict[namekey] = v
-embedding.load_state_dict(new_state_dict)
-
-network = StyTR.StyTrans(vgg,decoder,embedding,Trans,args)
-network.eval()
-network.to(device)
-
-
-
-content_tf = test_transform(content_size, crop)
-style_tf = test_transform(style_size, crop)
-
-for content_path in content_paths:
-    for style_path in style_paths:
-
-
-        content_tf1 = content_transform()
-        content = content_tf(Image.open(content_path).convert("RGB"))
-
-        h,w,c=np.shape(content)
-        style_tf1 = style_transform(h,w)
-        style = style_tf(Image.open(style_path).convert("RGB"))
-
-
-        style = style.to(device).unsqueeze(0)
-        content = content.to(device).unsqueeze(0)
-
-        with torch.no_grad():
-            output= network(content,style)
-        output = output[0].cpu()
-
-        output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
-            output_path, splitext(basename(content_path))[0],
-            splitext(basename(style_path))[0], save_ext
-        )
-
-        save_image(output, output_name)
-
-
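The deleted test.py was the batch command-line runner from the original StyTR-2 code; its model setup and inference loop have been folded into the Transformer() function of styleTransfer.py above, which takes PIL images directly instead of --content/--style paths and loads the same *_iter_160000.pth checkpoints (now expected in the working directory rather than experiments/). A hedged sketch of the in-process equivalent of one run of the removed script, with placeholder paths:

```python
from PIL import Image
from styleTransfer import Transformer  # replaces the deleted CLI in this commit

content = Image.open("input/sofa.jpg")     # placeholder, standing in for the old --content flag
style = Image.open("input/pattern.jpg")    # placeholder, standing in for the old --style flag

# Transformer() rebuilds the StyTR-2 network, loads vgg_normalised.pth and the three
# *_iter_160000.pth checkpoints from the working directory, and runs one stylization pass.
output = Transformer(content, style)
```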