Spaces · Build error
Sophie98 committed · ad1ac8f
1 parent: 6048967
change to streamlit
Browse files
- .flake8 +0 -12
- .streamlit/config.toml +3 -0
- README.md +2 -2
- Segmentation/{model_checkpoint.h5 → model_final.h5} +2 -2
- Segmentation/segmentation.py +24 -39
- StyleTransfer/{StyTR.py → srcTransformer/StyTR.py} +89 -60
- StyleTransfer/{models → srcTransformer/Transformer_models}/decoder_iter_160000.pth +0 -0
- StyleTransfer/{models → srcTransformer/Transformer_models}/embedding_iter_160000.pth +0 -0
- StyleTransfer/{models → srcTransformer/Transformer_models}/transformer_iter_160000.pth +0 -0
- StyleTransfer/{models → srcTransformer/Transformer_models}/vgg_normalised.pth +0 -0
- StyleTransfer/{ViT_helper.py → srcTransformer/ViT_helper.py} +53 -35
- StyleTransfer/srcTransformer/__init__.py +0 -0
- StyleTransfer/{function.py → srcTransformer/function.py} +27 -22
- StyleTransfer/{misc.py → srcTransformer/misc.py} +114 -71
- StyleTransfer/{transformer.py → srcTransformer/transformer.py} +213 -119
- StyleTransfer/styleTransfer.py +136 -70
- app.py +290 -134
- {gradio_cached_examples/output → figures}/0.png +0 -0
- {gradio_cached_examples/output → figures}/1.png +0 -0
- {gradio_cached_examples/output → figures}/2.png +0 -0
- figures/StyleGANsofa.png +0 -0
- figures/Transformersofa.jpg +0 -0
- figures/logo.png +0 -0
- gradio_cached_examples/log.csv +0 -4
- packages.txt +0 -3
- requirements.txt +3 -3
.flake8
DELETED
@@ -1,12 +0,0 @@
-[flake8]
-exclude =
-    .git,
-    *.egg-info,
-    __pycache__,
-    .tox,
-    .pytest_cache,
-    build,
-    dist,
-    tests
-max-line-length = 88
-ignore = D202,W503,E203  # conflicts with black
.streamlit/config.toml
ADDED
@@ -0,0 +1,3 @@
+[theme]
+base="dark"
+primaryColor="#04b188"
README.md
CHANGED
@@ -3,8 +3,8 @@ title: SofaStyler
 emoji: π
 colorFrom: blue
 colorTo: green
-sdk:
-sdk_version:
+sdk: streamlit
+sdk_version: 1.9.0
 app_file: app.py
 pinned: false
 ---
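The sdk and app_file metadata above are what tell Hugging Face Spaces to launch the app with Streamlit 1.9.0 (effectively `streamlit run app.py`), with the [theme] block in .streamlit/config.toml picked up automatically. As a hedged illustration only (this is not the repository's actual app.py, which is rewritten elsewhere in this commit; the widgets and captions are assumptions), an entry point of the kind that metadata points at looks like this:

# minimal_app_sketch.py — illustrative only, not the real app.py
import streamlit as st
from PIL import Image

st.title("SofaStyler")
uploaded = st.file_uploader("Upload a room photo", type=["jpg", "png"])
if uploaded is not None:
    st.image(Image.open(uploaded), caption="Input image")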
Segmentation/{model_checkpoint.h5 → model_final.h5}
RENAMED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:9a456f38c83897d9d8b5c8dd989ff7ee2fe13bb123a70a00b6e987d4efac1c6e
+size 130858696
Segmentation/segmentation.py
CHANGED
@@ -1,60 +1,45 @@
+# Import libraries
+
 import cv2
 from tensorflow import keras
 import numpy as np
 from PIL import Image
 import segmentation_models as sm
-sm.set_framework('tf.keras')

-model_path = "Segmentation/model_checkpoint.h5"
-CLASSES = ['sofa']
-BACKBONE = 'resnet50'
+sm.set_framework("tf.keras")

-activation = 'sigmoid' if n_classes == 1 else 'softmax'
+# Load segmentation model
+BACKBONE = "resnet50"
 preprocess_input = sm.get_preprocessing(BACKBONE)
-model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
-# define optimizer
-optim = keras.optimizers.Adam(LR)
-dice_loss = sm.losses.DiceLoss()
-focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
-total_loss = dice_loss + (1 * focal_loss)
-metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
-# compile keras model with defined optimizer, loss and metrics
-model.compile(optim, total_loss, metrics)
-model.load_weights(model_path)
+model = keras.models.load_model("Segmentation/model_final.h5", compile=False)


-def get_mask(image:Image
+def get_mask(image: Image) -> Image:
     """
     This function generates a mask of the image that highlights all the sofas
     in the image. This uses a pre-trained Unet model with a resnet50 backbone.
     Remark: The model was trained on 640by640 images and it is therefore best
     that the image has the same size.

     Parameters:
         image = original image
     Return:
         mask = corresponding mask of the image
     """
-    # load model
-    #model = keras.models.load_model('model_final.h5', compile=False)
-    print('loaded model')
     test_img = np.array(image)
     test_img = cv2.resize(test_img, (640, 640))
     test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
     test_img = np.expand_dims(test_img, axis=0)

     prediction = model.predict(preprocess_input(np.array(test_img))).round()
-    mask = Image.fromarray(prediction[...,0].squeeze()*255).convert("L")
+    mask = Image.fromarray(prediction[..., 0].squeeze() * 255).convert("L")
     return mask


-def replace_sofa(image:Image
+def replace_sofa(image: Image, mask: Image, styled_sofa: Image) -> Image:
     """
     This function replaces the original sofa in the image by the new styled
     sofa according to the mask.
     Remark: All images should have the same size.
     Input:
         image = Original image
@@ -63,11 +48,11 @@ def replace_sofa(image:Image.Image, mask:Image.Image, styled_sofa:Image.Image) -
     Return:
         new_image = Image containing the styled sofa
     """
-    image,mask,styled_sofa = np.array(image),np.array(mask),np.array(styled_sofa)
+    image, mask, styled_sofa = np.array(image), np.array(mask), np.array(styled_sofa)

     _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
     mask_inv = cv2.bitwise_not(mask)
-    image_bg = cv2.bitwise_and(image,image,mask=mask_inv)
-    sofa_fg = cv2.bitwise_and(styled_sofa,styled_sofa,mask=mask)
-    new_image = cv2.add(image_bg,sofa_fg)
+    image_bg = cv2.bitwise_and(image, image, mask=mask_inv)
+    sofa_fg = cv2.bitwise_and(styled_sofa, styled_sofa, mask=mask)
+    new_image = cv2.add(image_bg, sofa_fg)
     return Image.fromarray(new_image)
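For reference, a hedged usage sketch of the two functions above (not part of this commit; the figure paths are illustrative stand-ins and both images are assumed to be 640x640 RGB):

from PIL import Image
from Segmentation.segmentation import get_mask, replace_sofa

room = Image.open("figures/0.png").convert("RGB").resize((640, 640))    # original photo
styled = Image.open("figures/1.png").convert("RGB").resize((640, 640))  # stand-in for the styled sofa image
mask = get_mask(room)                       # L-mode mask, sofa pixels = 255
result = replace_sofa(room, mask, styled)   # paste the styled sofa back into the room
result.save("styled_room.png")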
StyleTransfer/{StyTR.py → srcTransformer/StyTR.py}
RENAMED
@@ -1,17 +1,24 @@
 import torch
 import torch.nn.functional as F
-    accuracy, get_world_size, interpolate,
-    is_dist_avail_and_initialized)
-from StyleTransfer.function import normal,normal_style
-from StyleTransfer.function import calc_mean_std
-from StyleTransfer.ViT_helper import DropPath, to_2tuple, trunc_normal_
+from StyleTransfer.srcTransformer.function import calc_mean_std, normal
+from StyleTransfer.srcTransformer.misc import (
+    NestedTensor,
+    nested_tensor_from_tensor_list,
+)
+from StyleTransfer.srcTransformer.ViT_helper import to_2tuple
 from torch import nn


 class PatchEmbed(nn.Module):
+    """Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        img_size: int = 256,
+        patch_size: int = 8,
+        in_chans: int = 3,
+        embed_dim: int = 512,
+    ):
         super().__init__()
         img_size = to_2tuple(img_size)
         patch_size = to_2tuple(patch_size)
@@ -19,9 +26,11 @@ class PatchEmbed(nn.Module):
         self.img_size = img_size
         self.patch_size = patch_size
         self.num_patches = num_patches
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+        )
+        self.up1 = nn.Upsample(scale_factor=2, mode="nearest")

     def forward(self, x):
         B, C, H, W = x.shape
@@ -34,7 +43,7 @@ decoder = nn.Sequential(
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(512, 256, (3, 3)),
     nn.ReLU(),
+    nn.Upsample(scale_factor=2, mode="nearest"),
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(256, 256, (3, 3)),
     nn.ReLU(),
@@ -47,14 +56,14 @@ decoder = nn.Sequential(
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(256, 128, (3, 3)),
     nn.ReLU(),
+    nn.Upsample(scale_factor=2, mode="nearest"),
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(128, 128, (3, 3)),
     nn.ReLU(),
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(128, 64, (3, 3)),
     nn.ReLU(),
+    nn.Upsample(scale_factor=2, mode="nearest"),
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(64, 64, (3, 3)),
     nn.ReLU(),
@@ -115,26 +124,35 @@ vgg = nn.Sequential(
     nn.ReLU(),  # relu5-3
     nn.ReflectionPad2d((1, 1, 1, 1)),
     nn.Conv2d(512, 512, (3, 3)),
-    nn.ReLU()  # relu5-4
+    nn.ReLU(),  # relu5-4
 )


 class MLP(nn.Module):
+    """Very simple multi-layer perceptron (also called FFN)"""

+    def __init__(
+        self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
+    ):
         super().__init__()
         self.num_layers = num_layers
         h = [hidden_dim] * (num_layers - 1)
+        self.layers = nn.ModuleList(
+            nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
+        )

     def forward(self, x):
         for i, layer in enumerate(self.layers):
             x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
         return x


 class StyTrans(nn.Module):
+    """This is the style transform transformer module"""

+    def __init__(
+        self, encoder: nn.Sequential, decoder: nn.Sequential, PatchEmbed, transformer
+    ):

         super().__init__()
         enc_layers = list(encoder.children())
@@ -143,85 +161,96 @@ class StyTrans(nn.Module):
         self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
         self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
         self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4_1 -> relu5_1

+        for name in ["enc_1", "enc_2", "enc_3", "enc_4", "enc_5"]:
             for param in getattr(self, name).parameters():
                 param.requires_grad = False

         self.mse_loss = nn.MSELoss()
         self.transformer = transformer
-        hidden_dim = transformer.d_model
         self.decode = decoder
         self.embedding = PatchEmbed

     def encode_with_intermediate(self, input):
         results = [input]
         for i in range(5):
+            func = getattr(self, "enc_{:d}".format(i + 1))
             results.append(func(results[-1]))
         return results[1:]

     def calc_content_loss(self, input, target):
+        assert input.size() == target.size()
+        assert target.requires_grad is False
+        return self.mse_loss(input, target)

     def calc_style_loss(self, input, target):
+        assert input.size() == target.size()
+        assert target.requires_grad is False
         input_mean, input_std = calc_mean_std(input)
         target_mean, target_std = calc_mean_std(target)
+        return self.mse_loss(input_mean, target_mean) + self.mse_loss(
+            input_std, target_std
+        )

+    def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
+        """The forward expects a NestedTensor, which consists of:
+        - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
+        - samples.mask: a binary mask of shape [batch_size x H x W],
+          containing 1 on padded pixels
         """
         content_input = samples_c
         style_input = samples_s
         if isinstance(samples_c, (list, torch.Tensor)):
+            samples_c = nested_tensor_from_tensor_list(
+                samples_c
+            )  # support different-sized images; padding is used for the mask [tensor, mask]
         if isinstance(samples_s, (list, torch.Tensor)):
             samples_s = nested_tensor_from_tensor_list(samples_s)

+        # features used to calculate the losses
         content_feats = self.encode_with_intermediate(samples_c.tensors)
         style_feats = self.encode_with_intermediate(samples_s.tensors)

+        # Linear projection
         style = self.embedding(samples_s.tensors)
         content = self.embedding(samples_c.tensors)

         # positional embedding is calculated in transformer.py
         pos_s = None
         pos_c = None

         mask = None
+        hs = self.transformer(style, mask, content, pos_c, pos_s)
         Ics = self.decode(hs)

         Ics_feats = self.encode_with_intermediate(Ics)
+        loss_c = self.calc_content_loss(
+            normal(Ics_feats[-1]), normal(content_feats[-1])
+        ) + self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
         # Style loss
         loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
         for i in range(1, 5):
             loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])

+        Icc = self.decode(self.transformer(content, mask, content, pos_c, pos_c))
+        Iss = self.decode(self.transformer(style, mask, style, pos_s, pos_s))

+        # Identity losses lambda 1
+        loss_lambda1 = self.calc_content_loss(
+            Icc, content_input
+        ) + self.calc_content_loss(Iss, style_input)

+        # Identity losses lambda 2
+        Icc_feats = self.encode_with_intermediate(Icc)
+        Iss_feats = self.encode_with_intermediate(Iss)
+        loss_lambda2 = self.calc_content_loss(
+            Icc_feats[0], content_feats[0]
+        ) + self.calc_content_loss(Iss_feats[0], style_feats[0])
         for i in range(1, 5):
+            loss_lambda2 += self.calc_content_loss(
+                Icc_feats[i], content_feats[i]
+            ) + self.calc_content_loss(Iss_feats[i], style_feats[i])
         # Please select and comment out one of the following two sentences
+        return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2  # train
         # return Ics #test
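A hedged wiring sketch for the module above (the real loading lives in StyleTransfer/styleTransfer.py, which this commit also rewrites; loading only the VGG weights is shown and is an assumption based on the checkpoints moved into srcTransformer/Transformer_models):

import torch
from StyleTransfer.srcTransformer import StyTR, transformer

# build the network from the pieces defined in StyTR.py
vgg = StyTR.vgg
vgg.load_state_dict(torch.load("StyleTransfer/srcTransformer/Transformer_models/vgg_normalised.pth"))
network = StyTR.StyTrans(vgg, StyTR.decoder, StyTR.PatchEmbed(), transformer.Transformer())
network.eval()  # forward(content, style) returns the stylised image plus the four losses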
StyleTransfer/{models → srcTransformer/Transformer_models}/decoder_iter_160000.pth
RENAMED
File without changes

StyleTransfer/{models → srcTransformer/Transformer_models}/embedding_iter_160000.pth
RENAMED
File without changes

StyleTransfer/{models → srcTransformer/Transformer_models}/transformer_iter_160000.pth
RENAMED
File without changes

StyleTransfer/{models → srcTransformer/Transformer_models}/vgg_normalised.pth
RENAMED
File without changes
StyleTransfer/{ViT_helper.py → srcTransformer/ViT_helper.py}
RENAMED
@@ -1,18 +1,30 @@
+import math
+import warnings
+from itertools import repeat
+
 import torch
 from torch import nn
+from torch._six import container_abcs
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    """
+    Drop paths (Stochastic Depth) per sample (when applied in main
+    path of residual blocks). This is the same as the DropConnect impl
+    I created for EfficientNet, etc networks, however, the original name
+    is misleading as 'Drop Connect' is a different form of dropout in a
+    separate paper... See discussion:
+    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
+    I've opted for changing the layer and argument names to 'drop path'
+    rather than mix DropConnect as a layer name and use 'survival rate'
+    as the argument.
     """
-    if drop_prob == 0. or not training:
+    if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (
+        x.ndim - 1
+    )  # work with diff dim tensors, not just 2D ConvNets
     random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
     random_tensor.floor_()  # binarize
     output = x.div(keep_prob) * random_tensor
@@ -20,25 +32,26 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):


 class DropPath(nn.Module):
-    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
-    """
+    """
+    Drop paths (Stochastic Depth) per sample
+    (when applied in main path of residual blocks).
+    """
+
+    def __init__(self, drop_prob: float = None):
         super(DropPath, self).__init__()
         self.drop_prob = drop_prob

     def forward(self, x):
         return drop_path(x, self.drop_prob, self.training)

-from itertools import repeat
-from torch._six import container_abcs

 # From PyTorch internals
-def _ntuple(n):
+def _ntuple(n: int):
     def parse(x):
         if isinstance(x, container_abcs.Iterable):
             return x
         return tuple(repeat(x, n))
+
     return parse

@@ -48,41 +61,41 @@ to_3tuple = _ntuple(3)
 to_4tuple = _ntuple(4)


+def _no_grad_trunc_normal_(
+    tensor: torch.tensor, mean: float, std: float, a: float, b: float
+):
+    # Cut & paste from PyTorch official master
+    # until it's in a few official releases - RW
+    # Method based on:
+    # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
     def norm_cdf(x):
         # Computes standard normal cumulative distribution function
-        return (1. + math.erf(x / math.sqrt(2.))) / 2.
+        return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0

     if (mean < a - 2 * std) or (mean > b + 2 * std):
+        warnings.warn(
+            "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
+            "The distribution of values may be incorrect.",
+            stacklevel=2,
+        )

     with torch.no_grad():
         # Values are generated by using a truncated uniform distribution and
         # then using the inverse CDF for the normal distribution.
         # Get upper and lower cdf values
+        lower = norm_cdf((a - mean) / std)
+        upper = norm_cdf((b - mean) / std)

         # Uniformly fill tensor with values from [l, u], then translate to
         # [2l-1, 2u-1].
-        tensor.uniform_(2 *
+        tensor.uniform_(2 * lower - 1, 2 * upper - 1)

         # Use inverse cdf transform for normal distribution to get truncated
         # standard normal
         tensor.erfinv_()

         # Transform to proper mean, std
-        tensor.mul_(std * math.sqrt(2.))
+        tensor.mul_(std * math.sqrt(2.0))
         tensor.add_(mean)

         # Clamp to ensure it's in the proper range
@@ -90,8 +103,13 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
     return tensor


+def trunc_normal_(
+    tensor: torch.tensor,
+    mean: float = 0.0,
+    std: float = 1.0,
+    a: float = -2.0,
+    b: float = 2.0,
+):
     r"""Fills the input Tensor with values drawn from a truncated
     normal distribution. The values are effectively drawn from the
     normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
@@ -108,4 +126,4 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
     >>> w = torch.empty(3, 5)
     >>> nn.init.trunc_normal_(w)
     """
     return _no_grad_trunc_normal_(tensor, mean, std, a, b)
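A small usage sketch for the helpers above (illustrative only; the tensor size is an arbitrary assumption):

import torch
from StyleTransfer.srcTransformer.ViT_helper import trunc_normal_, to_2tuple

w = torch.empty(512, 768)
trunc_normal_(w, std=0.02)   # in-place init, values clamped to the [a, b] = [-2.0, 2.0] range
print(to_2tuple(8))          # (8, 8) — used above for img_size / patch_size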
StyleTransfer/srcTransformer/__init__.py
ADDED
File without changes
StyleTransfer/{function.py → srcTransformer/function.py}
RENAMED
@@ -4,35 +4,41 @@ import torch
 def calc_mean_std(feat, eps=1e-5):
     # eps is a small value added to the variance to avoid divide-by-zero.
     size = feat.size()
+    assert len(size) == 4
     N, C = size[:2]
     feat_var = feat.view(N, C, -1).var(dim=2) + eps
     feat_std = feat_var.sqrt().view(N, C, 1, 1)
     feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
     return feat_mean, feat_std


 def calc_mean_std1(feat, eps=1e-5):
     # eps is a small value added to the variance to avoid divide-by-zero.
     size = feat.size()
     # assert (len(size) == 4)
-    WH,N, C = size
+    WH, N, C = size
     feat_var = feat.var(dim=0) + eps
     feat_std = feat_var.sqrt()
     feat_mean = feat.mean(dim=0)
     return feat_mean, feat_std


 def normal(feat, eps=1e-5):
-    feat_mean, feat_std= calc_mean_std(feat, eps)
-    normalized=(feat-feat_mean)/feat_std
+    feat_mean, feat_std = calc_mean_std(feat, eps)
+    normalized = (feat - feat_mean) / feat_std
     return normalized


 def normal_style(feat, eps=1e-5):
-    feat_mean, feat_std= calc_mean_std1(feat, eps)
-    normalized=(feat-feat_mean)/feat_std
+    feat_mean, feat_std = calc_mean_std1(feat, eps)
+    normalized = (feat - feat_mean) / feat_std
     return normalized


 def _calc_feat_flatten_mean_std(feat):
     # takes 3D feat (C, H, W), return mean and std of array within channels
+    assert feat.size()[0] == 3
+    assert isinstance(feat, torch.FloatTensor)
     feat_flatten = feat.view(3, -1)
     mean = feat_flatten.mean(dim=-1, keepdim=True)
     std = feat_flatten.std(dim=-1, keepdim=True)
@@ -49,25 +55,24 @@ def coral(source, target):
     # Note: flatten -> f

     source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
+    source_f_norm = (
+        source_f - source_f_mean.expand_as(source_f)
+    ) / source_f_std.expand_as(source_f)
+    source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

     target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
+    target_f_norm = (
+        target_f - target_f_mean.expand_as(target_f)
+    ) / target_f_std.expand_as(target_f)
+    target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

     source_f_norm_transfer = torch.mm(
         _mat_sqrt(target_f_cov_eye),
-        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
-                 source_f_norm)
+        torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm),
     )

+    source_f_transfer = source_f_norm_transfer * target_f_std.expand_as(
+        source_f_norm
+    ) + target_f_mean.expand_as(source_f_norm)

     return source_f_transfer.view(source.size())
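A quick sanity-check sketch for the normalisation helpers above (not part of the commit; the feature-map shape is an arbitrary assumption):

import torch
from StyleTransfer.srcTransformer.function import calc_mean_std, normal

feat = torch.randn(2, 512, 32, 32)          # N x C x H x W feature map
mean, std = calc_mean_std(feat)             # both have shape (2, 512, 1, 1)
normed = normal(feat)                       # per-sample, per-channel whitening
print(normed.view(2, 512, -1).mean(dim=2).abs().max())  # close to 0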
StyleTransfer/{misc.py → srcTransformer/misc.py}
RENAMED
@@ -4,20 +4,21 @@ Misc functions, including distributed helpers.

 Mostly copy-paste from torchvision references.
 """
+import datetime
 import os
+import pickle
 import subprocess
 import time
 from collections import defaultdict, deque
-from typing import Optional, List
+from typing import List, Optional

 import torch
 import torch.distributed as dist
-from torch import Tensor

 # needed due to empty tensor bug in pytorch and torchvision 0.5
 import torchvision
+from torch import Tensor
+
 if float(torchvision.__version__[:3]) < 0.7:
     from torchvision.ops import _new_empty_tensor
     from torchvision.ops.misc import _output_size
@@ -47,7 +48,7 @@ class SmoothedValue(object):
     """
     if not is_dist_avail_and_initialized():
         return
+    t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
     dist.barrier()
     dist.all_reduce(t)
     t = t.tolist()
@@ -82,7 +83,8 @@ class SmoothedValue(object):
     avg=self.avg,
     global_avg=self.global_avg,
     max=self.max,
-    value=self.value
+    value=self.value,
+)
@@ -116,7 +118,9 @@ def all_gather(data):
     for _ in size_list:
         tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
     if local_size != max_size:
+        padding = torch.empty(
+            size=(max_size - local_size,), dtype=torch.uint8, device="cuda"
+        )
         tensor = torch.cat((tensor, padding), dim=0)
     dist.all_gather(tensor_list, tensor)
@@ -172,15 +176,14 @@ class MetricLogger(object):
         return self.meters[attr]
     if attr in self.__dict__:
         return self.__dict__[attr]
+    raise AttributeError(
+        "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
+    )

     def __str__(self):
         loss_str = []
         for name, meter in self.meters.items():
-            loss_str.append(
-                "{}: {}".format(name, str(meter))
-            )
+            loss_str.append("{}: {}".format(name, str(meter)))
         return self.delimiter.join(loss_str)

     def synchronize_between_processes(self):
@@ -193,31 +196,35 @@ class MetricLogger(object):
     def log_every(self, iterable, print_freq, header=None):
         i = 0
         if not header:
+            header = ""
         start_time = time.time()
         end = time.time()
+        iter_time = SmoothedValue(fmt="{avg:.4f}")
+        data_time = SmoothedValue(fmt="{avg:.4f}")
+        space_fmt = ":" + str(len(str(len(iterable)))) + "d"
         if torch.cuda.is_available():
+            log_msg = self.delimiter.join(
+                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}",
+                 "time: {time}", "data: {data}", "max mem: {memory:.0f}"]
+            )
         else:
+            log_msg = self.delimiter.join(
+                [header, "[{0" + space_fmt + "}/{1}]", "eta: {eta}", "{meters}",
+                 "time: {time}", "data: {data}"]
+            )
         MB = 1024.0 * 1024.0
         for obj in iterable:
             data_time.update(time.time() - end)
@@ -227,38 +234,54 @@ class MetricLogger(object):
             eta_seconds = iter_time.global_avg * (len(iterable) - i)
             eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
             if torch.cuda.is_available():
+                print(
+                    log_msg.format(
+                        i, len(iterable), eta=eta_string, meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                        memory=torch.cuda.max_memory_allocated() / MB,
+                    )
+                )
             else:
+                print(
+                    log_msg.format(
+                        i, len(iterable), eta=eta_string, meters=str(self),
+                        time=str(iter_time), data=str(data_time),
+                    )
+                )
             i += 1
             end = time.time()
         total_time = time.time() - start_time
         total_time_str = str(datetime.timedelta(seconds=int(total_time)))
+        print(
+            "{} Total time: {} ({:.4f} s / it)".format(
+                header, total_time_str, total_time / len(iterable)
+            )
+        )


 def get_sha():
     cwd = os.path.dirname(os.path.abspath(__file__))

     def _run(command):
+        return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()

+    sha = "N/A"
     diff = "clean"
+    branch = "N/A"
     try:
+        sha = _run(["git", "rev-parse", "HEAD"])
+        subprocess.check_output(["git", "diff"], cwd=cwd)
+        diff = _run(["git", "diff-index", "HEAD"])
         diff = "has uncommited changes" if diff else "clean"
+        branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
     except Exception:
         pass
     message = f"sha: {sha}, status: {diff}, branch: {branch}"
@@ -324,9 +347,9 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
         mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
         for img, pad_img, m in zip(tensor_list, tensor, mask):
             pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
-            m[: img.shape[1], :img.shape[2]] = False
+            m[: img.shape[1], : img.shape[2]] = False
     else:
+        raise ValueError("not supported")
     return NestedTensor(tensor, mask)
@@ -336,7 +359,9 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
 def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
     max_size = []
     for i in range(tensor_list[0].dim()):
+        max_size_i = torch.max(
+            torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
+        ).to(torch.int64)
         max_size.append(max_size_i)
     max_size = tuple(max_size)
@@ -348,11 +373,15 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
     padded_masks = []
     for img in tensor_list:
         padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
+        padded_img = torch.nn.functional.pad(
+            img, (0, padding[2], 0, padding[1], 0, padding[0])
+        )
         padded_imgs.append(padded_img)

         m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
+        padded_mask = torch.nn.functional.pad(
+            m, (0, padding[2], 0, padding[1]), "constant", 1
+        )
         padded_masks.append(padded_mask.to(torch.bool))

     tensor = torch.stack(padded_imgs)
@@ -366,10 +395,11 @@ def setup_for_distributed(is_master):
     This function disables printing when not in master process
     """
     import builtins as __builtin__

     builtin_print = __builtin__.print

     def print(*args, **kwargs):
+        force = kwargs.pop("force", False)
         if is_master or force:
             builtin_print(*args, **kwargs)
@@ -406,26 +436,31 @@ def save_on_master(*args, **kwargs):


 def init_distributed_mode(args):
+    if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
         args.rank = int(os.environ["RANK"])
+        args.world_size = int(os.environ["WORLD_SIZE"])
+        args.gpu = int(os.environ["LOCAL_RANK"])
+    elif "SLURM_PROCID" in os.environ:
+        args.rank = int(os.environ["SLURM_PROCID"])
         args.gpu = args.rank % torch.cuda.device_count()
     else:
+        print("Not using distributed mode")
         args.distributed = False
         return

     args.distributed = True

     torch.cuda.set_device(args.gpu)
+    args.dist_backend = "nccl"
+    print(
+        "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
+    )
+    torch.distributed.init_process_group(
+        backend=args.dist_backend,
+        init_method=args.dist_url,
+        world_size=args.world_size,
+        rank=args.rank,
+    )
     torch.distributed.barrier()
     setup_for_distributed(args.rank == 0)
@@ -449,8 +484,14 @@ def accuracy(output, target, topk=(1,)):
     return res


+def interpolate(
+    input: torch.tensor,
+    size: List[int] = None,
+    scale_factor: float = None,
+    mode: str = "nearest",
+    align_corners: bool = None,
+) -> torch.tensor:
     """
     Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
     This will eventually be supported natively by PyTorch, and this
@@ -466,4 +507,6 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners
         output_shape = list(input.shape[:-2]) + list(output_shape)
         return _new_empty_tensor(input, output_shape)
     else:
+        return torchvision.ops.misc.interpolate(
+            input, size, scale_factor, mode, align_corners
+        )
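A hedged sketch of how the padding helper above is used (illustrative tensor sizes, not from the commit):

import torch
from StyleTransfer.srcTransformer.misc import nested_tensor_from_tensor_list

imgs = [torch.randn(3, 480, 512), torch.randn(3, 512, 480)]  # different image sizes
nt = nested_tensor_from_tensor_list(imgs)
print(nt.tensors.shape)  # torch.Size([2, 3, 512, 512]) — batch padded to the max size
print(nt.mask.shape)     # torch.Size([2, 512, 512]) — True on padded pixels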
StyleTransfer/{transformer.py β srcTransformer/transformer.py}
RENAMED
@@ -1,40 +1,59 @@
|
|
1 |
import copy
|
2 |
-
|
|
|
3 |
|
|
|
4 |
import torch
|
5 |
import torch.nn.functional as F
|
6 |
-
from torch import
|
7 |
-
|
8 |
-
import numpy as np
|
9 |
-
import os
|
10 |
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
|
11 |
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"
|
12 |
-
class Transformer(nn.Module):
|
13 |
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
super().__init__()
|
19 |
|
20 |
-
encoder_layer = TransformerEncoderLayer(
|
21 |
-
|
|
|
22 |
encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
|
23 |
-
self.encoder_c = TransformerEncoder(
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
28 |
decoder_norm = nn.LayerNorm(d_model)
|
29 |
-
self.decoder = TransformerDecoder(
|
30 |
-
|
|
|
|
|
|
|
|
|
31 |
|
32 |
self._reset_parameters()
|
33 |
|
34 |
self.d_model = d_model
|
35 |
self.nhead = nhead
|
36 |
|
37 |
-
self.new_ps = nn.Conv2d(512
|
38 |
self.averagepooling = nn.AdaptiveAvgPool2d(18)
|
39 |
|
40 |
def _reset_parameters(self):
|
@@ -42,54 +61,64 @@ class Transformer(nn.Module):
|
|
42 |
if p.dim() > 1:
|
43 |
nn.init.xavier_uniform_(p)
|
44 |
|
45 |
-
def forward(self, style, mask
|
46 |
|
47 |
# content-aware positional embedding
|
48 |
-
content_pool = self.averagepooling(content)
|
49 |
pos_c = self.new_ps(content_pool)
|
50 |
-
pos_embed_c = F.interpolate(pos_c, mode=
|
51 |
|
52 |
-
|
53 |
style = style.flatten(2).permute(2, 0, 1)
|
54 |
if pos_embed_s is not None:
|
55 |
pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)
|
56 |
-
|
57 |
content = content.flatten(2).permute(2, 0, 1)
|
58 |
if pos_embed_c is not None:
|
59 |
pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)
|
60 |
-
|
61 |
-
|
62 |
style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
|
63 |
content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
|
64 |
-
hs = self.decoder(
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
69 |
H = int(np.sqrt(N))
|
70 |
hs = hs.permute(1, 2, 0)
|
71 |
-
hs = hs.view(B, C, -1,H)
|
72 |
|
73 |
return hs
|
74 |
|
75 |
|
76 |
class TransformerEncoder(nn.Module):
|
77 |
-
|
78 |
def __init__(self, encoder_layer, num_layers, norm=None):
|
79 |
super().__init__()
|
80 |
self.layers = _get_clones(encoder_layer, num_layers)
|
81 |
self.num_layers = num_layers
|
82 |
self.norm = norm
|
83 |
|
84 |
-
def forward(
|
85 |
-
|
86 |
-
|
87 |
-
|
|
|
|
|
|
|
88 |
output = src
|
89 |
-
|
90 |
for layer in self.layers:
|
91 |
-
output = layer(
|
92 |
-
|
|
|
|
|
|
|
|
|
93 |
|
94 |
if self.norm is not None:
|
95 |
output = self.norm(output)
|
@@ -98,7 +127,6 @@ class TransformerEncoder(nn.Module):
|
|
98 |
|
99 |
|
100 |
class TransformerDecoder(nn.Module):
|
101 |
-
|
102 |
def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
|
103 |
super().__init__()
|
104 |
self.layers = _get_clones(decoder_layer, num_layers)
|
@@ -106,23 +134,32 @@ class TransformerDecoder(nn.Module):
|
|
106 |
self.norm = norm
|
107 |
self.return_intermediate = return_intermediate
|
108 |
|
109 |
-
def forward(
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
|
|
|
|
|
|
|
|
116 |
output = tgt
|
117 |
|
118 |
intermediate = []
|
119 |
|
120 |
for layer in self.layers:
|
121 |
-
output = layer(
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
|
|
|
|
|
|
|
|
|
|
126 |
if self.return_intermediate:
|
127 |
intermediate.append(self.norm(output))
|
128 |
|
@@ -139,9 +176,15 @@ class TransformerDecoder(nn.Module):
|
|
139 |
|
140 |
|
141 |
class TransformerEncoderLayer(nn.Module):
|
142 |
-
|
143 |
-
|
144 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
145 |
super().__init__()
|
146 |
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
147 |
# Implementation of Feedforward model
|
@@ -160,16 +203,19 @@ class TransformerEncoderLayer(nn.Module):
|
|
160 |
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
161 |
return tensor if pos is None else tensor + pos
|
162 |
|
163 |
-
def forward_post(
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
|
|
|
|
168 |
q = k = self.with_pos_embed(src, pos)
|
169 |
# q = k = src
|
170 |
# print(q.size(),k.size(),src.size())
|
171 |
-
src2 = self.self_attn(
|
172 |
-
|
|
|
173 |
src = src + self.dropout1(src2)
|
174 |
src = self.norm1(src)
|
175 |
src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
|
@@ -177,33 +223,46 @@ class TransformerEncoderLayer(nn.Module):
|
|
177 |
src = self.norm2(src)
|
178 |
return src
|
179 |
|
180 |
-
def forward_pre(
|
181 |
-
|
182 |
-
|
183 |
-
|
|
|
|
|
|
|
184 |
src2 = self.norm1(src)
|
185 |
q = k = self.with_pos_embed(src2, pos)
|
186 |
-
src2 = self.self_attn(
|
187 |
-
|
|
|
188 |
src = src + self.dropout1(src2)
|
189 |
src2 = self.norm2(src)
|
190 |
src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
|
191 |
src = src + self.dropout2(src2)
|
192 |
return src
|
193 |
|
194 |
-
def forward(
|
195 |
-
|
196 |
-
|
197 |
-
|
|
|
|
|
|
|
198 |
if self.normalize_before:
|
199 |
return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
|
200 |
return self.forward_post(src, src_mask, src_key_padding_mask, pos)
|
201 |
|
202 |
|
203 |
class TransformerDecoderLayer(nn.Module):
|
204 |
-
|
205 |
-
|
206 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
207 |
super().__init__()
|
208 |
# d_model embedding dim
|
209 |
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
@@ -226,28 +285,35 @@ class TransformerDecoderLayer(nn.Module):
|
|
226 |
def with_pos_embed(self, tensor, pos: Optional[Tensor]):
|
227 |
return tensor if pos is None else tensor + pos
|
228 |
|
229 |
-
def forward_post(
|
230 |
-
|
231 |
-
|
232 |
-
|
233 |
-
|
234 |
-
|
235 |
-
|
|
|
|
|
|
|
|
|
236 |
|
237 |
-
|
238 |
q = self.with_pos_embed(tgt, query_pos)
|
239 |
k = self.with_pos_embed(memory, pos)
|
240 |
-
v = memory
|
241 |
-
|
242 |
-
tgt2 = self.self_attn(
|
243 |
-
|
244 |
-
|
|
|
245 |
tgt = tgt + self.dropout1(tgt2)
|
246 |
tgt = self.norm1(tgt)
|
247 |
-
tgt2 = self.multihead_attn(
|
248 |
-
|
249 |
-
|
250 |
-
|
|
|
|
|
|
|
251 |
tgt = tgt + self.dropout2(tgt2)
|
252 |
tgt = self.norm2(tgt)
|
253 |
tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
|
@@ -255,24 +321,32 @@ class TransformerDecoderLayer(nn.Module):
|
|
255 |
tgt = self.norm3(tgt)
|
256 |
return tgt
|
257 |
|
258 |
-
def forward_pre(
|
259 |
-
|
260 |
-
|
261 |
-
|
262 |
-
|
263 |
-
|
264 |
-
|
|
|
|
|
|
|
|
|
265 |
tgt2 = self.norm1(tgt)
|
266 |
q = k = self.with_pos_embed(tgt2, query_pos)
|
267 |
-
tgt2 = self.self_attn(
|
268 |
-
|
|
|
269 |
|
270 |
tgt = tgt + self.dropout1(tgt2)
|
271 |
tgt2 = self.norm2(tgt)
|
272 |
-
tgt2 = self.multihead_attn(
|
273 |
-
|
274 |
-
|
275 |
-
|
|
|
|
|
|
|
276 |
|
277 |
tgt = tgt + self.dropout2(tgt2)
|
278 |
tgt2 = self.norm3(tgt)
|
@@ -280,18 +354,38 @@ class TransformerDecoderLayer(nn.Module):
|
|
280 |
tgt = tgt + self.dropout3(tgt2)
|
281 |
return tgt
|
282 |
|
283 |
-
def forward(
|
284 |
-
|
285 |
-
|
286 |
-
|
287 |
-
|
288 |
-
|
289 |
-
|
|
|
|
|
|
|
|
|
290 |
if self.normalize_before:
|
291 |
-
return self.forward_pre(
|
292 |
-
|
293 |
-
|
294 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
295 |
|
296 |
|
297 |
def _get_clones(module, N):
|
@@ -319,4 +413,4 @@ def _get_activation_fn(activation):
|
|
319 |
return F.gelu
|
320 |
if activation == "glu":
|
321 |
return F.glu
|
322 |
-
raise RuntimeError(
|
|
|
import copy
+import os
+from typing import Optional

+import numpy as np
import torch
import torch.nn.functional as F
+from torch import Tensor, nn
+
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"

+
+class Transformer(nn.Module):
+    def __init__(
+        self,
+        d_model=512,
+        nhead=8,
+        num_encoder_layers=3,
+        num_decoder_layers=3,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+        return_intermediate_dec=False,
+    ):
        super().__init__()

+        encoder_layer = TransformerEncoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+        self.encoder_c = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+        self.encoder_s = TransformerEncoder(
+            encoder_layer, num_encoder_layers, encoder_norm
+        )
+
+        decoder_layer = TransformerDecoderLayer(
+            d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+        )
        decoder_norm = nn.LayerNorm(d_model)
+        self.decoder = TransformerDecoder(
+            decoder_layer,
+            num_decoder_layers,
+            decoder_norm,
+            return_intermediate=return_intermediate_dec,
+        )

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

+        self.new_ps = nn.Conv2d(512, 512, (1, 1))
        self.averagepooling = nn.AdaptiveAvgPool2d(18)

    def _reset_parameters(self):
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

+    def forward(self, style, mask, content, pos_embed_c, pos_embed_s):

        # content-aware positional embedding
+        content_pool = self.averagepooling(content)
        pos_c = self.new_ps(content_pool)
+        pos_embed_c = F.interpolate(pos_c, mode="bilinear", size=style.shape[-2:])

+        # flatten NxCxHxW to HWxNxC
        style = style.flatten(2).permute(2, 0, 1)
        if pos_embed_s is not None:
            pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)
+
        content = content.flatten(2).permute(2, 0, 1)
        if pos_embed_c is not None:
            pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)
+
        style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
        content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
+        hs = self.decoder(
+            content,
+            style,
+            memory_key_padding_mask=mask,
+            pos=pos_embed_s,
+            query_pos=pos_embed_c,
+        )[0]
+
+        # HWxNxC to NxCxHxW
+        N, B, C = hs.shape
        H = int(np.sqrt(N))
        hs = hs.permute(1, 2, 0)
+        hs = hs.view(B, C, -1, H)

        return hs


class TransformerEncoder(nn.Module):
    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

+    def forward(
+        self,
+        src,
+        mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
        output = src
+
        for layer in self.layers:
+            output = layer(
+                output,
+                src_mask=mask,
+                src_key_padding_mask=src_key_padding_mask,
+                pos=pos,
+            )

        if self.norm is not None:
            output = self.norm(output)


class TransformerDecoder(nn.Module):
    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.norm = norm
        self.return_intermediate = return_intermediate

+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
        output = tgt

        intermediate = []

        for layer in self.layers:
+            output = layer(
+                output,
+                memory,
+                tgt_mask=tgt_mask,
+                memory_mask=memory_mask,
+                tgt_key_padding_mask=tgt_key_padding_mask,
+                memory_key_padding_mask=memory_key_padding_mask,
+                pos=pos,
+                query_pos=query_pos,
+            )
            if self.return_intermediate:
                intermediate.append(self.norm(output))


class TransformerEncoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

+    def forward_post(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
        q = k = self.with_pos_embed(src, pos)
        # q = k = src
        # print(q.size(),k.size(),src.size())
+        src2 = self.self_attn(
+            q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = self.norm2(src)
        return src

+    def forward_pre(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
+        src2 = self.self_attn(
+            q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+        )[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

+    def forward(
+        self,
+        src,
+        src_mask: Optional[Tensor] = None,
+        src_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+    ):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        d_model,
+        nhead,
+        dim_feedforward=2048,
+        dropout=0.1,
+        activation="relu",
+        normalize_before=False,
+    ):
        super().__init__()
        # d_model embedding dim
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

+    def forward_post(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):

        q = self.with_pos_embed(tgt, query_pos)
        k = self.with_pos_embed(memory, pos)
+        v = memory
+
+        tgt2 = self.self_attn(
+            q, k, v, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]
+
        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = self.norm3(tgt)
        return tgt

+    def forward_pre(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
+        tgt2 = self.self_attn(
+            q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+        )[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
+        tgt2 = self.multihead_attn(
+            query=self.with_pos_embed(tgt2, query_pos),
+            key=self.with_pos_embed(memory, pos),
+            value=memory,
+            attn_mask=memory_mask,
+            key_padding_mask=memory_key_padding_mask,
+        )[0]

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt = tgt + self.dropout3(tgt2)
        return tgt

+    def forward(
+        self,
+        tgt,
+        memory,
+        tgt_mask: Optional[Tensor] = None,
+        memory_mask: Optional[Tensor] = None,
+        tgt_key_padding_mask: Optional[Tensor] = None,
+        memory_key_padding_mask: Optional[Tensor] = None,
+        pos: Optional[Tensor] = None,
+        query_pos: Optional[Tensor] = None,
+    ):
        if self.normalize_before:
+            return self.forward_pre(
+                tgt,
+                memory,
+                tgt_mask,
+                memory_mask,
+                tgt_key_padding_mask,
+                memory_key_padding_mask,
+                pos,
+                query_pos,
+            )
+        return self.forward_post(
+            tgt,
+            memory,
+            tgt_mask,
+            memory_mask,
+            tgt_key_padding_mask,
+            memory_key_padding_mask,
+            pos,
+            query_pos,
+        )


def _get_clones(module, N):
        return F.gelu
    if activation == "glu":
        return F.glu
+    raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
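For orientation, a minimal usage sketch of the module above (an illustrative assumption, not part of the commit): the forward pass takes 512-channel style and content feature maps plus an optional padding mask and positional embeddings, and returns an NxCxHxW feature map with the spatial size of the style features. The dummy shapes below are assumed, not taken from the repository.

# Usage sketch (assumption, not part of the commit).
import torch

from StyleTransfer.srcTransformer.transformer import Transformer

trans = Transformer(d_model=512, nhead=8)
trans.eval()

content_feat = torch.randn(1, 512, 40, 40)  # stand-in for 512-channel content features
style_feat = torch.randn(1, 512, 40, 40)    # stand-in for 512-channel style features

with torch.no_grad():
    # mask and the positional embeddings may be None; forward builds a
    # content-aware positional embedding internally via new_ps/averagepooling
    hs = trans(style_feat, None, content_feat, None, None)
print(hs.shape)  # NxCxHxW, here (1, 512, 40, 40)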
StyleTransfer/styleTransfer.py
CHANGED
@@ -1,115 +1,181 @@
-from PIL import Image
import numpy as np
import torch
-print(torch.cuda.is_available())
import torch.nn as nn
from torchvision import transforms
-import StyleTransfer.transformer as transformer
-import StyleTransfer.StyTR as StyTR
-from collections import OrderedDict
-import tensorflow_hub as tfhub
-import tensorflow as tf
-import paddlehub as phub
-import os
-
-############################################# TRANSFORMER ############################################

    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
-    decoder = StyTR.decoder
-    Trans = transformer.Transformer()
-    embedding = StyTR.PatchEmbed()
    decoder.eval()
    Trans.eval()
    vgg.eval()

-    new_state_dict = OrderedDict()
    state_dict = torch.load(decoder_path)
    decoder.load_state_dict(state_dict)

-    new_state_dict = OrderedDict()
    state_dict = torch.load(Trans_path)
    Trans.load_state_dict(state_dict)

-    new_state_dict = OrderedDict()
    state_dict = torch.load(embedding_path)
    embedding.load_state_dict(state_dict)

-    network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
    network.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network.to(device)
-    content = content_tf(content_img.convert("RGB"))
    style = style_tf(style_img.convert("RGB"))
    style = style.to(device).unsqueeze(0)
    content = content.to(device).unsqueeze(0)
    with torch.no_grad():
-        output= network(content,style)
    output = output[0].cpu().squeeze()
-    output =
    return Image.fromarray(output)
-
-############################################## STYLE-FAST #############################################
-style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")

    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))

stylepro_artistic = phub.Module(name="stylepro_artistic")
-def StyleProjection(content_image:Image.Image,style_image:Image.Image) -> Image.Image:
-    print('line92')
-    result = stylepro_artistic.style_transfer(
-        images=[{
-            'content': np.array(content_image.convert('RGB') )[:, :, ::-1],
-            'styles': [np.array(style_image.convert('RGB') )[:, :, ::-1]]}],
-        alpha=0.8)
-    print('line97')
-    return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
-
-def create_styledSofa(content_image:Image.Image,style_image:Image.Image,choice:str) -> Image.Image:
-    if choice =="Style Transformer":
-        output = StyleTransformer(content_image,style_image)
-    elif choice =="Style FAST":
-        output = StyleFAST(content_image,style_image)
-    elif choice =="Style Projection":
-        output = StyleProjection(content_image,style_image)
-    else:
-        output = content_image
-    return output
import numpy as np
+import paddlehub as phub
+import StyleTransfer.srcTransformer.StyTR as StyTR
+import StyleTransfer.srcTransformer.transformer as transformer
+import tensorflow as tf
+import tensorflow_hub as tfhub
import torch
import torch.nn as nn
+from PIL import Image
from torchvision import transforms

+# TRANSFORMER
+
+vgg_path = "StyleTransfer/srcTransformer/Transformer_models/vgg_normalised.pth"
+decoder_path = "StyleTransfer/srcTransformer/Transformer_models/decoder_iter_160000.pth"
+Trans_path = (
+    "StyleTransfer/srcTransformer/Transformer_models/transformer_iter_160000.pth"
+)
+embedding_path = (
+    "StyleTransfer/srcTransformer/Transformer_models/embedding_iter_160000.pth"
+)
+
+
+def style_transform(h, w):
+    """
+    This function creates a transformation for the style image,
+    that crops it and formats it into a tensor.
+
+    Parameters:
+        h = height
+        w = width
+    Return:
+        transform = transformation pipeline
+    """
+    transform_list = []
+    transform_list.append(transforms.CenterCrop((h, w)))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

+
+def content_transform():
+    """
+    This function simply creates a transformation pipeline,
+    that formats the content image into a tensor.
+
+    Returns:
+        transform = the transformation pipeline
+    """
+    transform_list = []
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

+
+# This loads the network architecture already at building time
+vgg = StyTR.vgg
+vgg.load_state_dict(torch.load(vgg_path))
+vgg = nn.Sequential(*list(vgg.children())[:44])
+decoder = StyTR.decoder
+Trans = transformer.Transformer()
+embedding = StyTR.PatchEmbed()
+# The (square) shape of the content and style image is fixed
+content_size = 640
+style_size = 640
+
+
def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
+    """
+    This function creates the Transformer network and applies it on
+    a content and style image to create a styled image.
+
+    Parameters:
+        content_img = the image with the content
+        style_img = the image with the style/pattern
+    Returns:
+        output = an image that is a combination of both
+    """
+
    decoder.eval()
    Trans.eval()
    vgg.eval()

    state_dict = torch.load(decoder_path)
    decoder.load_state_dict(state_dict)

    state_dict = torch.load(Trans_path)
    Trans.load_state_dict(state_dict)

    state_dict = torch.load(embedding_path)
    embedding.load_state_dict(state_dict)

+    network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
    network.eval()
+
+    content_tf = content_transform()
+    style_tf = style_transform(style_size, style_size)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    network.to(device)
+    content = content_tf(content_img.convert("RGB"))
    style = style_tf(style_img.convert("RGB"))
    style = style.to(device).unsqueeze(0)
    content = content.to(device).unsqueeze(0)
    with torch.no_grad():
+        output = network(content, style)
    output = output[0].cpu().squeeze()
+    output = (
+        output.mul(255)
+        .add_(0.5)
+        .clamp_(0, 255)
+        .permute(1, 2, 0)
+        .to("cpu", torch.uint8)
+        .numpy()
+    )
    return Image.fromarray(output)

+
+# STYLE-FAST
+
+style_transfer_model = tfhub.load(
+    "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
+)
+
+
+def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
+    """
+    This function applies a Fast image style transfer technique,
+    which uses a pretrained model from tensorhub.
+
+    Parameters:
+        content_image = the image with the content
+        style_image = the image with the style/pattern
+    Returns:
+        stylized_image = an image that is a combination of both
+    """
+    content_image = (
+        tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...]
+        / 255.0
+    )
+    style_image = (
+        tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.0
+    )
    output = style_transfer_model(content_image, style_image)
    stylized_image = output[0]
    return Image.fromarray(np.uint8(stylized_image[0] * 255))

+
+# STYLE PROJECTION
+
stylepro_artistic = phub.Module(name="stylepro_artistic")


+def styleProjection(
+    content_image: Image.Image, style_image: Image.Image, alpha: float = 1.0
+):
+    """
+    This function uses parameter free style transfer,
+    based on a model from paddlehub.
+    There is an optional weight parameter alpha, which
+    allows to control the balance between image and style.
+
+    Parameters:
+        content_image = the image with the content
+        style_image = the image with the style/pattern
+        alpha = weight for the image vs style.
+            This should be a float between 0 and 1.
+    Returns:
+        result = an image that is a combination of both
+    """
+    result = stylepro_artistic.style_transfer(
+        images=[
+            {
+                "content": np.array(content_image.convert("RGB"))[:, :, ::-1],
+                "styles": [np.array(style_image.convert("RGB"))[:, :, ::-1]],
+            }
+        ],
+        alpha=alpha,
+    )
+
+    return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")
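A rough usage sketch of these three entry points (an illustrative assumption, not part of the commit; importing the module loads the StyTR weights, the TF-Hub model and the PaddleHub module at import time): each one maps two PIL images, content and style, to a styled PIL image. The example file names are only placeholders taken from the figures folder.

# Usage sketch (assumption, not part of the commit).
from PIL import Image

from StyleTransfer.styleTransfer import StyleFAST, StyleTransformer, styleProjection

sofa = Image.open("figures/0.png").convert("RGB").resize((640, 640))      # content photo
pattern = Image.open("figures/1.png").convert("RGB").resize((640, 640))   # style pattern

styled_a = StyleTransformer(sofa, pattern)             # StyTR transformer network
styled_b = StyleFAST(sofa, pattern)                    # magenta arbitrary stylization
styled_c = styleProjection(sofa, pattern, alpha=0.8)   # paddlehub, alpha blends image vs style
styled_a.save("styled_sofa.png")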
app.py
CHANGED
@@ -1,50 +1,71 @@
-
-from typing import Tuple
-
-import gradio as gr
import numpy as np
from Segmentation.segmentation import get_mask, replace_sofa
-from StyleTransfer.styleTransfer import

def fix_orient(img: Image.Image) -> Image.Image:
    for orientation in ExifTags.TAGS.keys():
-        if ExifTags.TAGS[orientation]==
            break
    info = dict(info.items())
-    orientation
    return img


-def resize_sofa(img: Image.Image) ->
    """
-    This function adds padding to make the original image square
-    such that it can be reverted later.
    Parameters:
        img = original image
    Return:
-        box
    """
    width, height = img.size
    idx = np.argmin([width, height])
    newsize = (640, 640)  # parameters from test script

    if idx == 0:
        box = (
            newsize[0] * (1 - width / height) // 2,
            0,
@@ -52,22 +73,22 @@ def resize_sofa(img: Image.Image) -> Tuple[Image.Image, tuple]:
            newsize[1],
        )
    else:
        box = (
            0,
            newsize[1] * (1 - height / width) // 2,
            newsize[0],
            newsize[1] - newsize[1] * (1 - height / width) // 2,
        )
-    return


def resize_style(img: Image.Image) -> Image.Image:
    """
-    This function generates a zoomed out version of
-    image and resizes it to a 640by640 square.
    Parameters:
        img = image containing the style/pattern
    Return:
@@ -88,114 +109,249 @@ def resize_style(img: Image.Image) -> Image.Image:
    top = 0
    bottom = height
    newsize = (640, 640)  # parameters from test script

    # Constructs a zoomed-out version
    copies = 8
    resize = (newsize[0] // copies, newsize[1] // copies)
    for row in range(copies):
        for column in range(copies):
-    return


-""