Spaces:
Build error
Build error
Sophie98
commited on
Commit
•
8e6efc1
1
Parent(s):
a402a5e
make transformer global
Browse files- app.py +3 -1
- styleTransfer.py +67 -72
app.py
CHANGED
@@ -90,10 +90,12 @@ def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
|
|
90 |
input_img,style_img = Image.fromarray(input_img),Image.fromarray(style_img)
|
91 |
resized_img,box = resize_sofa(input_img)
|
92 |
resized_style = resize_style(style_img)
|
|
|
93 |
# generate mask for image
|
94 |
mask = get_mask(resized_img)
|
|
|
95 |
styled_sofa = create_styledSofa(resized_img,resized_style)
|
96 |
-
|
97 |
# postprocess the final image
|
98 |
new_sofa = replace_sofa(resized_img,mask,styled_sofa)
|
99 |
new_sofa = new_sofa.crop(box)
|
90 |
input_img,style_img = Image.fromarray(input_img),Image.fromarray(style_img)
|
91 |
resized_img,box = resize_sofa(input_img)
|
92 |
resized_style = resize_style(style_img)
|
93 |
+
resized_style.save('resized_style.jpg')
|
94 |
# generate mask for image
|
95 |
mask = get_mask(resized_img)
|
96 |
+
mask.save('mask.jpg')
|
97 |
styled_sofa = create_styledSofa(resized_img,resized_style)
|
98 |
+
styled_sofa.save('styled_sofa.jpg')
|
99 |
# postprocess the final image
|
100 |
new_sofa = replace_sofa(resized_img,mask,styled_sofa)
|
101 |
new_sofa = new_sofa.crop(box)
|
styleTransfer.py
CHANGED
@@ -11,10 +11,15 @@ from collections import OrderedDict
|
|
11 |
import tensorflow_hub as hub
|
12 |
import tensorflow as tf
|
13 |
|
14 |
-
|
15 |
-
|
16 |
|
17 |
############################################# TRANSFORMER ############################################
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
def test_transform(size, crop):
|
19 |
transform_list = []
|
20 |
|
@@ -40,82 +45,75 @@ def content_transform():
|
|
40 |
transform = transforms.Compose(transform_list)
|
41 |
return transform
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
style_tf = test_transform(style_size, crop)
|
95 |
-
|
96 |
content_tf1 = content_transform()
|
97 |
content = content_tf(content_img.convert("RGB"))
|
98 |
-
|
99 |
h,w,c=np.shape(content)
|
100 |
style_tf1 = style_transform(h,w)
|
101 |
style = style_tf(style_img.convert("RGB"))
|
102 |
-
|
103 |
-
|
104 |
style = style.to(device).unsqueeze(0)
|
105 |
content = content.to(device).unsqueeze(0)
|
106 |
|
107 |
with torch.no_grad():
|
108 |
output= network(content,style)
|
109 |
-
|
110 |
-
output = output
|
111 |
-
|
112 |
-
print(output.squeeze().shape)
|
113 |
-
torch2PIL = transforms.ToPILImage()
|
114 |
-
output = torch2PIL(output.squeeze())
|
115 |
-
return output
|
116 |
|
117 |
############################################## STYLE-GAN #############################################
|
118 |
|
|
|
|
|
119 |
def perform_style_transfer(content_image, style_image):
|
120 |
content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
|
121 |
style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
|
@@ -130,10 +128,7 @@ def create_styledSofa(sofa:Image, style:Image):
|
|
130 |
styled_sofa = StyleTransformer(sofa,style)
|
131 |
return styled_sofa
|
132 |
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
137 |
-
# image = image.crop(box)
|
138 |
-
# print(box)
|
139 |
-
# image.show()
|
11 |
import tensorflow_hub as hub
|
12 |
import tensorflow as tf
|
13 |
|
14 |
+
from torchvision.utils import save_image
|
|
|
15 |
|
16 |
############################################# TRANSFORMER ############################################
|
17 |
+
|
18 |
+
vgg_path = 'vgg_normalised.pth'
|
19 |
+
decoder_path = 'decoder_iter_160000.pth'
|
20 |
+
Trans_path = 'transformer_iter_160000.pth'
|
21 |
+
embedding_path = 'embedding_iter_160000.pth'
|
22 |
+
|
23 |
def test_transform(size, crop):
|
24 |
transform_list = []
|
25 |
|
45 |
transform = transforms.Compose(transform_list)
|
46 |
return transform
|
47 |
|
48 |
+
# Advanced options
|
49 |
+
content_size=640
|
50 |
+
style_size=640
|
51 |
+
crop='store_true'
|
52 |
+
|
53 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
54 |
+
|
55 |
+
vgg = StyTR.vgg
|
56 |
+
vgg.load_state_dict(torch.load(vgg_path))
|
57 |
+
vgg = nn.Sequential(*list(vgg.children())[:44])
|
58 |
+
|
59 |
+
decoder = StyTR.decoder
|
60 |
+
Trans = transformer.Transformer()
|
61 |
+
embedding = StyTR.PatchEmbed()
|
62 |
+
|
63 |
+
decoder.eval()
|
64 |
+
Trans.eval()
|
65 |
+
vgg.eval()
|
66 |
+
|
67 |
+
new_state_dict = OrderedDict()
|
68 |
+
state_dict = torch.load(decoder_path)
|
69 |
+
for k, v in state_dict.items():
|
70 |
+
#namekey = k[7:] # remove `module.`
|
71 |
+
namekey = k
|
72 |
+
new_state_dict[namekey] = v
|
73 |
+
decoder.load_state_dict(new_state_dict)
|
74 |
+
|
75 |
+
new_state_dict = OrderedDict()
|
76 |
+
state_dict = torch.load(Trans_path)
|
77 |
+
for k, v in state_dict.items():
|
78 |
+
#namekey = k[7:] # remove `module.`
|
79 |
+
namekey = k
|
80 |
+
new_state_dict[namekey] = v
|
81 |
+
Trans.load_state_dict(new_state_dict)
|
82 |
+
|
83 |
+
new_state_dict = OrderedDict()
|
84 |
+
state_dict = torch.load(embedding_path)
|
85 |
+
for k, v in state_dict.items():
|
86 |
+
#namekey = k[7:] # remove `module.`
|
87 |
+
namekey = k
|
88 |
+
new_state_dict[namekey] = v
|
89 |
+
embedding.load_state_dict(new_state_dict)
|
90 |
+
|
91 |
+
network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
|
92 |
+
network.eval()
|
93 |
+
network.to(device)
|
94 |
+
|
95 |
+
content_tf = test_transform(content_size, crop)
|
96 |
+
style_tf = test_transform(style_size, crop)
|
97 |
+
|
98 |
+
def StyleTransformer(content_img: Image, style_img: Image):
|
|
|
|
|
99 |
content_tf1 = content_transform()
|
100 |
content = content_tf(content_img.convert("RGB"))
|
|
|
101 |
h,w,c=np.shape(content)
|
102 |
style_tf1 = style_transform(h,w)
|
103 |
style = style_tf(style_img.convert("RGB"))
|
|
|
|
|
104 |
style = style.to(device).unsqueeze(0)
|
105 |
content = content.to(device).unsqueeze(0)
|
106 |
|
107 |
with torch.no_grad():
|
108 |
output= network(content,style)
|
109 |
+
output = output[0].cpu().squeeze()
|
110 |
+
output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
|
111 |
+
return Image.fromarray(output)
|
|
|
|
|
|
|
|
|
112 |
|
113 |
############################################## STYLE-GAN #############################################
|
114 |
|
115 |
+
style_transfer_model = hub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")
|
116 |
+
|
117 |
def perform_style_transfer(content_image, style_image):
|
118 |
content_image = tf.convert_to_tensor(content_image, np.float32)[tf.newaxis, ...] / 255.
|
119 |
style_image = tf.convert_to_tensor(style_image, np.float32)[tf.newaxis, ...] / 255.
|
128 |
styled_sofa = StyleTransformer(sofa,style)
|
129 |
return styled_sofa
|
130 |
|
131 |
+
image = Image.open('sofa.jpg')
|
132 |
+
style = Image.open('style.jpg')
|
133 |
+
output = create_styledSofa(image,style)
|
134 |
+
output.show()
|
|
|
|
|
|