Spaces:
Build error
Build error
Sophie98
commited on
Commit
•
ab7b996
1
Parent(s):
cfa3c46
fix errors
Browse files- styleTransfer.py +10 -37
styleTransfer.py
CHANGED
@@ -11,8 +11,6 @@ from collections import OrderedDict
|
|
11 |
import tensorflow_hub as hub
|
12 |
import tensorflow as tf
|
13 |
|
14 |
-
from torchvision.utils import save_image
|
15 |
-
|
16 |
############################################# TRANSFORMER ############################################
|
17 |
|
18 |
vgg_path = 'vgg_normalised.pth'
|
@@ -20,16 +18,6 @@ decoder_path = 'decoder_iter_160000.pth'
|
|
20 |
Trans_path = 'transformer_iter_160000.pth'
|
21 |
embedding_path = 'embedding_iter_160000.pth'
|
22 |
|
23 |
-
def test_transform(size, crop):
|
24 |
-
transform_list = []
|
25 |
-
|
26 |
-
if size != 0:
|
27 |
-
transform_list.append(transforms.Resize(size))
|
28 |
-
if crop:
|
29 |
-
transform_list.append(transforms.CenterCrop(size))
|
30 |
-
transform_list.append(transforms.ToTensor())
|
31 |
-
transform = transforms.Compose(transform_list)
|
32 |
-
return transform
|
33 |
def style_transform(h,w):
|
34 |
k = (h,w)
|
35 |
size = int(np.max(k))
|
@@ -48,7 +36,6 @@ def content_transform():
|
|
48 |
# Advanced options
|
49 |
content_size=640
|
50 |
style_size=640
|
51 |
-
crop='store_true'
|
52 |
|
53 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 |
|
@@ -66,42 +53,28 @@ vgg.eval()
|
|
66 |
|
67 |
new_state_dict = OrderedDict()
|
68 |
state_dict = torch.load(decoder_path)
|
69 |
-
|
70 |
-
#namekey = k[7:] # remove `module.`
|
71 |
-
namekey = k
|
72 |
-
new_state_dict[namekey] = v
|
73 |
-
decoder.load_state_dict(new_state_dict)
|
74 |
|
75 |
new_state_dict = OrderedDict()
|
76 |
state_dict = torch.load(Trans_path)
|
77 |
-
|
78 |
-
#namekey = k[7:] # remove `module.`
|
79 |
-
namekey = k
|
80 |
-
new_state_dict[namekey] = v
|
81 |
-
Trans.load_state_dict(new_state_dict)
|
82 |
|
83 |
new_state_dict = OrderedDict()
|
84 |
state_dict = torch.load(embedding_path)
|
85 |
-
|
86 |
-
#namekey = k[7:] # remove `module.`
|
87 |
-
namekey = k
|
88 |
-
new_state_dict[namekey] = v
|
89 |
-
embedding.load_state_dict(new_state_dict)
|
90 |
|
91 |
network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
|
92 |
network.eval()
|
93 |
|
|
|
|
|
|
|
94 |
def StyleTransformer(content_img: Image, style_img: Image):
|
95 |
|
96 |
network.to(device)
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
content_tf1 = content_transform()
|
101 |
-
content = content_tf1(content_img.convert("RGB"))
|
102 |
-
h,w,c=np.shape(content)
|
103 |
-
style_tf1 = style_transform(h,w)
|
104 |
-
style = style_tf1(style_img.convert("RGB"))
|
105 |
style = style.to(device).unsqueeze(0)
|
106 |
content = content.to(device).unsqueeze(0)
|
107 |
|
@@ -128,4 +101,4 @@ def StyleGAN(content_image, style_image):
|
|
128 |
def create_styledSofa(sofa:Image, style:Image):
|
129 |
#styled_sofa = StyleGAN(sofa,style)
|
130 |
styled_sofa = StyleTransformer(sofa,style)
|
131 |
-
return styled_sofa
|
11 |
import tensorflow_hub as hub
|
12 |
import tensorflow as tf
|
13 |
|
|
|
|
|
14 |
############################################# TRANSFORMER ############################################
|
15 |
|
16 |
vgg_path = 'vgg_normalised.pth'
|
18 |
Trans_path = 'transformer_iter_160000.pth'
|
19 |
embedding_path = 'embedding_iter_160000.pth'
|
20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
21 |
def style_transform(h,w):
|
22 |
k = (h,w)
|
23 |
size = int(np.max(k))
|
36 |
# Advanced options
|
37 |
content_size=640
|
38 |
style_size=640
|
|
|
39 |
|
40 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
41 |
|
53 |
|
54 |
new_state_dict = OrderedDict()
|
55 |
state_dict = torch.load(decoder_path)
|
56 |
+
decoder.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
57 |
|
58 |
new_state_dict = OrderedDict()
|
59 |
state_dict = torch.load(Trans_path)
|
60 |
+
Trans.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
61 |
|
62 |
new_state_dict = OrderedDict()
|
63 |
state_dict = torch.load(embedding_path)
|
64 |
+
embedding.load_state_dict(state_dict)
|
|
|
|
|
|
|
|
|
65 |
|
66 |
network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
|
67 |
network.eval()
|
68 |
|
69 |
+
content_tf = content_transform()
|
70 |
+
style_tf = style_transform(style_size,style_size)
|
71 |
+
|
72 |
def StyleTransformer(content_img: Image, style_img: Image):
|
73 |
|
74 |
network.to(device)
|
75 |
+
|
76 |
+
content = content_tf(content_img.convert("RGB"))
|
77 |
+
style = style_tf(style_img.convert("RGB"))
|
|
|
|
|
|
|
|
|
|
|
78 |
style = style.to(device).unsqueeze(0)
|
79 |
content = content.to(device).unsqueeze(0)
|
80 |
|
101 |
def create_styledSofa(sofa:Image, style:Image):
|
102 |
#styled_sofa = StyleGAN(sofa,style)
|
103 |
styled_sofa = StyleTransformer(sofa,style)
|
104 |
+
return styled_sofa
|