Commit ab92204 — "test commit"
Sophie98 committed
1 Parent(s): 993904f

Files changed:
- .gitattributes +2 -0
- StyTR.py +230 -0
- sofaApp.py +50 -0
- style_example1.jpg +0 -0
- test.py +176 -0
- transformer.py +322 -0
.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+*.pth filter=lfs diff=lfs merge=lfs -text
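Note: the two added lines register *.pth checkpoint files with Git LFS; running `git lfs track "*.pth"` would append the same filter line to .gitattributes (the first added line here is just a blank separator).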
StyTR.py ADDED
@@ -0,0 +1,230 @@
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import box_ops
from misc import (NestedTensor, nested_tensor_from_tensor_list,
                  accuracy, get_world_size, interpolate,
                  is_dist_avail_and_initialized)
from function import normal, normal_style
from function import calc_mean_std
import scipy.stats as stats
from ViT_helper import DropPath, to_2tuple, trunc_normal_

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=256, patch_size=8, in_chans=3, embed_dim=512):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)
        return x


decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1, this is the last layer used
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x

class StyTrans(nn.Module):
    """ This is the style transform transformer module """

    def __init__(self, encoder, decoder, PatchEmbed, transformer, args):
        super().__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])    # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])  # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*enc_layers[31:44]) # relu4_1 -> relu5_1

        # the VGG encoder stays frozen; only transformer, embedding and decoder train
        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

        self.mse_loss = nn.MSELoss()
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.decode = decoder
        self.embedding = PatchEmbed

    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(5):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def calc_content_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
        """ The forward expects a NestedTensor, which consists of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
            - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        """
        content_input = samples_c
        style_input = samples_s
        if isinstance(samples_c, (list, torch.Tensor)):
            samples_c = nested_tensor_from_tensor_list(samples_c)  # different-sized images are padded; the padding is recorded in the mask [tensor, mask]
        if isinstance(samples_s, (list, torch.Tensor)):
            samples_s = nested_tensor_from_tensor_list(samples_s)

        # features used to calculate the losses
        content_feats = self.encode_with_intermediate(samples_c.tensors)
        style_feats = self.encode_with_intermediate(samples_s.tensors)

        # linear projection
        style = self.embedding(samples_s.tensors)
        content = self.embedding(samples_c.tensors)

        # positional embedding is calculated in transformer.py
        pos_s = None
        pos_c = None

        mask = None
        hs = self.transformer(style, mask, content, pos_c, pos_s)
        Ics = self.decode(hs)

        # content loss
        Ics_feats = self.encode_with_intermediate(Ics)
        loss_c = self.calc_content_loss(normal(Ics_feats[-1]), normal(content_feats[-1])) + \
                 self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
        # style loss
        loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])

        Icc = self.decode(self.transformer(content, mask, content, pos_c, pos_c))
        Iss = self.decode(self.transformer(style, mask, style, pos_s, pos_s))

        # identity loss lambda 1
        loss_lambda1 = self.calc_content_loss(Icc, content_input) + self.calc_content_loss(Iss, style_input)

        # identity loss lambda 2
        Icc_feats = self.encode_with_intermediate(Icc)
        Iss_feats = self.encode_with_intermediate(Iss)
        loss_lambda2 = self.calc_content_loss(Icc_feats[0], content_feats[0]) + self.calc_content_loss(Iss_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_lambda2 += self.calc_content_loss(Icc_feats[i], content_feats[i]) + self.calc_content_loss(Iss_feats[i], style_feats[i])
        # Keep exactly one of the following two return statements uncommented:
        return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2  # train
        # return Ics  # test
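As a quick sanity check on PatchEmbed above, the following minimal sketch (an illustration only; it assumes this file is importable as StyTR with its local dependencies — misc, function, box_ops, ViT_helper — on the path) shows a 256×256 image being projected to a 32×32 grid of 512-dimensional patch tokens:

import torch
from StyTR import PatchEmbed  # assumes the local modules imported by StyTR.py are available

embed = PatchEmbed(img_size=256, patch_size=8, in_chans=3, embed_dim=512)
x = torch.randn(1, 3, 256, 256)   # one RGB image
print(embed(x).shape)             # torch.Size([1, 512, 32, 32]); 32 = 256 // 8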
sofaApp.py ADDED
@@ -0,0 +1,50 @@
import numpy as np
import gradio as gr
from Segmentation.segmentation import get_mask, replace_sofa
from StyleTransfer.styleTransfer import resize_sofa, resize_style, create_styledSofa
from PIL import Image

def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
    """
    Styles (all) the sofas in the image to the given style.
    This function uses a transformer to combine the image with the desired style according
    to a generated mask of the sofas in the image.
    Input:
        input_img = image containing a sofa
        style_img = image containing a style
    Return:
        new_sofa = image containing the styled sofa
    """
    # preprocess input images into (640, 640) squares to fit the requirements of the segmentation model
    resized_img = resize_sofa(input_img)
    resized_style = resize_style(style_img)
    # generate a mask for the image
    mask = get_mask(resized_img)
    styled_sofa = create_styledSofa(resized_img, resized_style)
    new_sofa = replace_sofa(resized_img, mask, styled_sofa)
    return new_sofa

image = gr.inputs.Image()
style = gr.inputs.Image()

demo = gr.Interface(
    style_sofa,
    [image, style],
    'image',
    examples=[
        ['input/sofa_example1.jpg', 'input/style_example1.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example2.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example3.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example4.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example5.jpg'],
    ],
    title="Style your sofa",
    description="🛋 Customize your sofa to your wildest dreams! 🛋",
)

if __name__ == "__main__":
    demo.launch(share=True)


# https://github.com/dhawan98/Post-Processing-of-Image-Segmentation-using-CRF
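A hypothetical smoke test for style_sofa outside the Gradio UI could look like the sketch below (it assumes the Segmentation and StyleTransfer packages imported above are installed alongside this file, and that replace_sofa returns an array-like image):

import numpy as np
from PIL import Image
from sofaApp import style_sofa  # hypothetical direct use of the app's pipeline

sofa = np.array(Image.open('input/sofa_example1.jpg'))
style = np.array(Image.open('input/style_example1.jpg'))
result = style_sofa(sofa, style)  # segment the sofa, restyle it, recompose the photo
Image.fromarray(np.asarray(result)).save('styled_sofa.jpg')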
style_example1.jpg
ADDED
test.py ADDED
@@ -0,0 +1,176 @@
import argparse
from pathlib import Path
import os
import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image
from function import calc_mean_std, normal, coral
import transformer as transformer
import StyTR as StyTR
import matplotlib.pyplot as plt
from matplotlib import cm
import numpy as np

def test_transform(size, crop):
    transform_list = []
    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

def style_transform(h, w):
    k = (h, w)
    size = int(np.max(k))
    transform_list = []
    transform_list.append(transforms.CenterCrop((h, w)))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform

def content_transform():
    transform_list = []
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform


parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str,
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
                    help='File path to the style image, or multiple style \
                    images separated by commas if you want to do style \
                    interpolation or spatial control')
parser.add_argument('--style_dir', type=str,
                    help='Directory path to a batch of style images')
parser.add_argument('--output', type=str, default='output',
                    help='Directory to save the output image(s)')
parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')

parser.add_argument('--style_interpolation_weights', type=str, default="")
parser.add_argument('--a', type=float, default=1.0)
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                    help="Type of positional embedding to use on top of the image features")
parser.add_argument('--hidden_dim', default=512, type=int,
                    help="Size of the embeddings (dimension of the transformer)")
args = parser.parse_args()

# Advanced options
content_size = 640
style_size = 640
crop = True  # was the argparse leftover 'store_true'; any truthy value enables center-cropping
save_ext = '.jpg'
output_path = args.output
preserve_color = True  # was 'store_true'; currently unused below
alpha = args.a

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Either --content or --content_dir should be given.
if args.content:
    content_paths = [Path(args.content)]
else:
    content_dir = Path(args.content_dir)
    content_paths = [f for f in content_dir.glob('*')]

# Either --style or --style_dir should be given.
if args.style:
    style_paths = [Path(args.style)]
else:
    style_dir = Path(args.style_dir)
    style_paths = [f for f in style_dir.glob('*')]

if not os.path.exists(output_path):
    os.mkdir(output_path)


vgg = StyTR.vgg
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:44])  # keep layers up to relu5_1

decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

decoder.eval()
Trans.eval()
vgg.eval()

from collections import OrderedDict
new_state_dict = OrderedDict()
state_dict = torch.load(args.decoder_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # strip the `module.` prefix when loading a DataParallel checkpoint
    namekey = k
    new_state_dict[namekey] = v
decoder.load_state_dict(new_state_dict)

new_state_dict = OrderedDict()
state_dict = torch.load(args.Trans_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # strip the `module.` prefix when loading a DataParallel checkpoint
    namekey = k
    new_state_dict[namekey] = v
Trans.load_state_dict(new_state_dict)

new_state_dict = OrderedDict()
state_dict = torch.load(args.embedding_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # strip the `module.` prefix when loading a DataParallel checkpoint
    namekey = k
    new_state_dict[namekey] = v
embedding.load_state_dict(new_state_dict)

network = StyTR.StyTrans(vgg, decoder, embedding, Trans, args)
network.eval()
network.to(device)


content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

for content_path in content_paths:
    for style_path in style_paths:
        print(content_path)

        content_tf1 = content_transform()
        content = content_tf(Image.open(content_path).convert("RGB"))

        h, w, c = np.shape(content)
        style_tf1 = style_transform(h, w)
        style = style_tf(Image.open(style_path).convert("RGB"))

        style = style.to(device).unsqueeze(0)
        content = content.to(device).unsqueeze(0)

        with torch.no_grad():
            output = network(content, style)
        output = output[0].cpu()  # first element of the returned tuple is the stylized image Ics

        output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
            output_path, splitext(basename(content_path))[0],
            splitext(basename(style_path))[0], save_ext
        )

        save_image(output, output_name)
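Given the argument definitions above, a typical single-pair invocation would be:

python test.py --content input/sofa_example1.jpg --style input/style_example1.jpg --output output

The pretrained .pth checkpoints default to the experiments/ directory (see --vgg, --decoder_path, --Trans_path, --embedding_path).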
transformer.py ADDED
@@ -0,0 +1,322 @@
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from function import normal, normal_style
import numpy as np
import os
device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"

class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=3,
                 num_decoder_layers=3, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder_c = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        self.encoder_s = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.new_ps = nn.Conv2d(512, 512, (1, 1))
        self.averagepooling = nn.AdaptiveAvgPool2d(18)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, style, mask, content, pos_embed_c, pos_embed_s):

        # content-aware positional embedding
        content_pool = self.averagepooling(content)
        pos_c = self.new_ps(content_pool)
        pos_embed_c = F.interpolate(pos_c, mode='bilinear', size=style.shape[-2:])

        # flatten NxCxHxW to HWxNxC
        style = style.flatten(2).permute(2, 0, 1)
        if pos_embed_s is not None:
            pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)

        content = content.flatten(2).permute(2, 0, 1)
        if pos_embed_c is not None:
            pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)

        style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
        content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
        hs = self.decoder(content, style, memory_key_padding_mask=mask,
                          pos=pos_embed_s, query_pos=pos_embed_c)[0]

        # HWxNxC back to NxCxHxW
        N, B, C = hs.shape
        H = int(np.sqrt(N))
        hs = hs.permute(1, 2, 0)
        hs = hs.view(B, C, -1, H)

        return hs


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        # q = k = src
        # print(q.size(), k.size(), src.size())
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # d_model is the embedding dimension
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of the feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):

        q = self.with_pos_embed(tgt, query_pos)
        k = self.with_pos_embed(memory, pos)
        v = memory

        tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
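A minimal shape sketch for the Transformer above (assuming this file is importable as transformer; note that importing it sets CUDA_VISIBLE_DEVICES as a side effect): 1×512×32×32 feature maps, as PatchEmbed produces for 256×256 inputs, are flattened internally to 1024×1×512 token sequences and reshaped back after decoding:

import torch
from transformer import Transformer

net = Transformer(d_model=512, nhead=8)      # defaults: 3 encoder + 3 decoder layers
content = torch.randn(1, 512, 32, 32)        # e.g. PatchEmbed output for a 256x256 image
style = torch.randn(1, 512, 32, 32)
out = net(style, None, content, None, None)  # (style, mask, content, pos_c, pos_s)
print(out.shape)                             # torch.Size([1, 512, 32, 32])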