Sophie98 committed
Commit ab92204
1 Parent(s): 993904f

test commit

Files changed (6)
  1. .gitattributes +2 -0
  2. StyTR.py +230 -0
  3. sofaApp.py +50 -0
  4. style_example1.jpg +0 -0
  5. test.py +176 -0
  6. transformer.py +322 -0
.gitattributes CHANGED
@@ -25,3 +25,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zstandard filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+
+*.pth filter=lfs diff=lfs merge=lfs -text
StyTR.py ADDED
@@ -0,0 +1,230 @@
import torch
import torch.nn.functional as F
from torch import nn
import numpy as np
import box_ops
from misc import (NestedTensor, nested_tensor_from_tensor_list,
                  accuracy, get_world_size, interpolate,
                  is_dist_avail_and_initialized)
from function import normal, normal_style
from function import calc_mean_std
import scipy.stats as stats
from ViT_helper import DropPath, to_2tuple, trunc_normal_


class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """
    def __init__(self, img_size=256, patch_size=8, in_chans=3, embed_dim=512):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.up1 = nn.Upsample(scale_factor=2, mode='nearest')

    def forward(self, x):
        B, C, H, W = x.shape
        x = self.proj(x)

        return x
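
A minimal shape check for the patch embedding above (editorial sketch, not part of the committed file; it assumes the default 256×256 input, so the 8×8-stride convolution produces a 32×32 grid of 512-dimensional patch features):

    import torch

    embed = PatchEmbed(img_size=256, patch_size=8, in_chans=3, embed_dim=512)
    feats = embed(torch.randn(1, 3, 256, 256))   # one RGB image as a stand-in
    print(feats.shape)                           # torch.Size([1, 512, 32, 32])
    print(embed.num_patches)                     # 1024 patches (32 * 32)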


decoder = nn.Sequential(
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 256, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 128, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 64, (3, 3)),
    nn.ReLU(),
    nn.Upsample(scale_factor=2, mode='nearest'),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 3, (3, 3)),
)

vgg = nn.Sequential(
    nn.Conv2d(3, 3, (1, 1)),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(3, 64, (3, 3)),
    nn.ReLU(),  # relu1-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 64, (3, 3)),
    nn.ReLU(),  # relu1-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(64, 128, (3, 3)),
    nn.ReLU(),  # relu2-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 128, (3, 3)),
    nn.ReLU(),  # relu2-2
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(128, 256, (3, 3)),
    nn.ReLU(),  # relu3-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 256, (3, 3)),
    nn.ReLU(),  # relu3-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(256, 512, (3, 3)),
    nn.ReLU(),  # relu4-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu4-4
    nn.MaxPool2d((2, 2), (2, 2), (0, 0), ceil_mode=True),
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-1
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-2
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU(),  # relu5-3
    nn.ReflectionPad2d((1, 1, 1, 1)),
    nn.Conv2d(512, 512, (3, 3)),
    nn.ReLU()  # relu5-4
)
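
This Sequential is only a container for pretrained VGG19 weights; StyTrans below slices its children into five sub-encoders (input→relu1_1, …, relu4_1→relu5_1), and test.py keeps the first 44 children. A minimal sketch of that use (editorial, not part of the file; the checkpoint path is the default test.py assumes):

    import torch
    from torch import nn

    vgg.load_state_dict(torch.load('./experiments/vgg_normalised.pth'))  # default path in test.py
    enc_layers = list(vgg.children())
    enc_to_relu4_1 = nn.Sequential(*enc_layers[:31])   # input -> relu4_1
    enc_to_relu5_1 = nn.Sequential(*enc_layers[:44])   # input -> relu5_1, the deepest feature used

    with torch.no_grad():
        feats = enc_to_relu5_1(torch.randn(1, 3, 256, 256))
    print(feats.shape)  # torch.Size([1, 512, 16, 16])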

class MLP(nn.Module):
    """ Very simple multi-layer perceptron (also called FFN)"""

    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
        super().__init__()
        self.num_layers = num_layers
        h = [hidden_dim] * (num_layers - 1)
        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))

    def forward(self, x):
        for i, layer in enumerate(self.layers):
            x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
        return x
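
MLP is not referenced elsewhere in this file; a minimal usage sketch with hypothetical dimensions (ReLU is applied after every linear layer except the last):

    mlp = MLP(input_dim=512, hidden_dim=256, output_dim=3, num_layers=3)
    y = mlp(torch.randn(10, 512))
    print(y.shape)   # torch.Size([10, 3])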

class StyTrans(nn.Module):
    """ This is the style transform transformer module """

    def __init__(self, encoder, decoder, PatchEmbed, transformer, args):
        super().__init__()
        enc_layers = list(encoder.children())
        self.enc_1 = nn.Sequential(*enc_layers[:4])     # input -> relu1_1
        self.enc_2 = nn.Sequential(*enc_layers[4:11])   # relu1_1 -> relu2_1
        self.enc_3 = nn.Sequential(*enc_layers[11:18])  # relu2_1 -> relu3_1
        self.enc_4 = nn.Sequential(*enc_layers[18:31])  # relu3_1 -> relu4_1
        self.enc_5 = nn.Sequential(*enc_layers[31:44])  # relu4_1 -> relu5_1

        for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
            for param in getattr(self, name).parameters():
                param.requires_grad = False

        self.mse_loss = nn.MSELoss()
        self.transformer = transformer
        hidden_dim = transformer.d_model
        self.decode = decoder
        self.embedding = PatchEmbed

    def encode_with_intermediate(self, input):
        results = [input]
        for i in range(5):
            func = getattr(self, 'enc_{:d}'.format(i + 1))
            results.append(func(results[-1]))
        return results[1:]

    def calc_content_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        return self.mse_loss(input, target)

    def calc_style_loss(self, input, target):
        assert (input.size() == target.size())
        assert (target.requires_grad is False)
        input_mean, input_std = calc_mean_std(input)
        target_mean, target_std = calc_mean_std(target)
        return self.mse_loss(input_mean, target_mean) + \
               self.mse_loss(input_std, target_std)

    def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
        """ The forward pass expects NestedTensors, which consist of:
            - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
            - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
        """
        content_input = samples_c
        style_input = samples_s
        if isinstance(samples_c, (list, torch.Tensor)):
            samples_c = nested_tensor_from_tensor_list(samples_c)  # pad different-sized images and store the padding mask: [tensor, mask]
        if isinstance(samples_s, (list, torch.Tensor)):
            samples_s = nested_tensor_from_tensor_list(samples_s)

        # features used to calculate the losses
        content_feats = self.encode_with_intermediate(samples_c.tensors)
        style_feats = self.encode_with_intermediate(samples_s.tensors)

        # linear projection
        style = self.embedding(samples_s.tensors)
        content = self.embedding(samples_c.tensors)

        # positional embedding is calculated in transformer.py
        pos_s = None
        pos_c = None

        mask = None
        hs = self.transformer(style, mask, content, pos_c, pos_s)
        Ics = self.decode(hs)

        Ics_feats = self.encode_with_intermediate(Ics)
        # Content loss
        loss_c = self.calc_content_loss(normal(Ics_feats[-1]), normal(content_feats[-1])) + \
                 self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
        # Style loss
        loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])

        Icc = self.decode(self.transformer(content, mask, content, pos_c, pos_c))
        Iss = self.decode(self.transformer(style, mask, style, pos_s, pos_s))

        # Identity loss lambda 1
        loss_lambda1 = self.calc_content_loss(Icc, content_input) + self.calc_content_loss(Iss, style_input)

        # Identity loss lambda 2
        Icc_feats = self.encode_with_intermediate(Icc)
        Iss_feats = self.encode_with_intermediate(Iss)
        loss_lambda2 = self.calc_content_loss(Icc_feats[0], content_feats[0]) + self.calc_content_loss(Iss_feats[0], style_feats[0])
        for i in range(1, 5):
            loss_lambda2 += self.calc_content_loss(Icc_feats[i], content_feats[i]) + self.calc_content_loss(Iss_feats[i], style_feats[i])
        # Keep exactly one of the following two return statements
        return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2  # train
        # return Ics  # test
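
A minimal end-to-end sketch of how these pieces fit together (editorial, mirroring what test.py below does; it assumes the misc, function, and ViT_helper modules of the original StyTR^2 repo are importable, and it leaves the decoder, transformer, and patch embedding randomly initialized, so it only checks tensor plumbing — real use would load the iter_160000 checkpoints as in test.py):

    import torch
    from torch import nn
    import transformer as transformer_module
    import StyTR

    vgg = StyTR.vgg
    vgg.load_state_dict(torch.load('./experiments/vgg_normalised.pth'))
    vgg = nn.Sequential(*list(vgg.children())[:44])       # keep layers up to relu5_1

    network = StyTR.StyTrans(
        encoder=vgg,
        decoder=StyTR.decoder,
        PatchEmbed=StyTR.PatchEmbed(),
        transformer=transformer_module.Transformer(),
        args=None,                                         # args is unused inside StyTrans
    )
    network.eval()

    content = torch.randn(1, 3, 256, 256)                  # stand-ins for real images
    style = torch.randn(1, 3, 256, 256)
    with torch.no_grad():
        Ics, loss_c, loss_s, loss_l1, loss_l2 = network(content, style)
    print(Ics.shape)                                       # torch.Size([1, 3, 256, 256])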
sofaApp.py ADDED
@@ -0,0 +1,50 @@
import numpy as np
import gradio as gr
from Segmentation.segmentation import get_mask, replace_sofa
from StyleTransfer.styleTransfer import resize_sofa, resize_style, create_styledSofa
from PIL import Image


def style_sofa(input_img: np.ndarray, style_img: np.ndarray):
    """
    Styles (all) the sofas in the image with the given style.
    This function uses a transformer to combine the image with the desired style according
    to a generated mask of the sofas in the image.
    Input:
        input_img = image containing a sofa
        style_img = image containing a style
    Return:
        new_sofa = image containing the styled sofa
    """

    # preprocess the input images into (640, 640) squares to fit the requirements of the segmentation model
    resized_img = resize_sofa(input_img)
    resized_style = resize_style(style_img)
    # generate a mask for the sofa in the image
    mask = get_mask(resized_img)
    styled_sofa = create_styledSofa(resized_img, resized_style)
    new_sofa = replace_sofa(resized_img, mask, styled_sofa)
    return new_sofa


image = gr.inputs.Image()
style = gr.inputs.Image()

demo = gr.Interface(
    style_sofa,
    [image, style],
    'image',
    examples=[
        ['input/sofa_example1.jpg', 'input/style_example1.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example2.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example3.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example4.jpg'],
        ['input/sofa_example1.jpg', 'input/style_example5.jpg'],
    ],
    title="Style your sofa",
    description="🛋 Customize your sofa to your wildest dreams! 🛋",
)

if __name__ == "__main__":
    demo.launch(share=True)


# https://github.com/dhawan98/Post-Processing-of-Image-Segmentation-using-CRF
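
For reference, a minimal sketch of exercising the pipeline without the Gradio UI (editorial; it assumes the example images listed above are present and that style_sofa returns an array-like image, as replace_sofa is expected to do):

    import numpy as np
    from PIL import Image

    sofa = np.array(Image.open('input/sofa_example1.jpg'))
    style = np.array(Image.open('input/style_example1.jpg'))
    result = style_sofa(sofa, style)   # mask generation + style transfer + compositing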
style_example1.jpg ADDED
test.py ADDED
@@ -0,0 +1,176 @@
import argparse
from pathlib import Path
import os
import torch
import torch.nn as nn
from PIL import Image
from os.path import basename
from os.path import splitext
from torchvision import transforms
from torchvision.utils import save_image
from function import calc_mean_std, normal, coral
import transformer as transformer
import StyTR as StyTR
import matplotlib.pyplot as plt
from matplotlib import cm
from function import normal
import numpy as np


def test_transform(size, crop):
    transform_list = []

    if size != 0:
        transform_list.append(transforms.Resize(size))
    if crop:
        transform_list.append(transforms.CenterCrop(size))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform


def style_transform(h, w):
    k = (h, w)
    size = int(np.max(k))
    transform_list = []
    transform_list.append(transforms.CenterCrop((h, w)))
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform


def content_transform():
    transform_list = []
    transform_list.append(transforms.ToTensor())
    transform = transforms.Compose(transform_list)
    return transform
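
A quick sketch of what these transforms produce (editorial; the blank 800×600 PIL image is a stand-in for a real photo, and content_size/style_size are both set to 640 below):

    from PIL import Image
    import torch

    img = Image.new('RGB', (800, 600))
    print(test_transform(640, crop=True)(img).shape)   # resize shorter side to 640, center-crop -> torch.Size([3, 640, 640])
    print(style_transform(640, 640)(img).shape)        # center-crop (zero-pads if smaller) -> torch.Size([3, 640, 640])
    print(content_transform()(img).shape)              # ToTensor only -> torch.Size([3, 600, 800])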

parser = argparse.ArgumentParser()
# Basic options
parser.add_argument('--content', type=str,
                    help='File path to the content image')
parser.add_argument('--content_dir', type=str,
                    help='Directory path to a batch of content images')
parser.add_argument('--style', type=str,
                    help='File path to the style image, or multiple style \
                    images separated by commas if you want to do style \
                    interpolation or spatial control')
parser.add_argument('--style_dir', type=str,
                    help='Directory path to a batch of style images')
parser.add_argument('--output', type=str, default='output',
                    help='Directory to save the output image(s)')
parser.add_argument('--vgg', type=str, default='./experiments/vgg_normalised.pth')
parser.add_argument('--decoder_path', type=str, default='experiments/decoder_iter_160000.pth')
parser.add_argument('--Trans_path', type=str, default='experiments/transformer_iter_160000.pth')
parser.add_argument('--embedding_path', type=str, default='experiments/embedding_iter_160000.pth')

parser.add_argument('--style_interpolation_weights', type=str, default="")
parser.add_argument('--a', type=float, default=1.0)
parser.add_argument('--position_embedding', default='sine', type=str, choices=('sine', 'learned'),
                    help="Type of positional embedding to use on top of the image features")
parser.add_argument('--hidden_dim', default=512, type=int,
                    help="Size of the embeddings (dimension of the transformer)")
args = parser.parse_args()

# Advanced options
content_size = 640
style_size = 640
crop = 'store_true'            # non-empty string, so center-cropping is always enabled
save_ext = '.jpg'
output_path = args.output
preserve_color = 'store_true'
alpha = args.a

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Either --content or --content_dir should be given.
if args.content:
    content_paths = [Path(args.content)]
else:
    content_dir = Path(args.content_dir)
    content_paths = [f for f in content_dir.glob('*')]

# Either --style or --style_dir should be given.
if args.style:
    style_paths = [Path(args.style)]
else:
    style_dir = Path(args.style_dir)
    style_paths = [f for f in style_dir.glob('*')]

if not os.path.exists(output_path):
    os.mkdir(output_path)


vgg = StyTR.vgg
vgg.load_state_dict(torch.load(args.vgg))
vgg = nn.Sequential(*list(vgg.children())[:44])

decoder = StyTR.decoder
Trans = transformer.Transformer()
embedding = StyTR.PatchEmbed()

decoder.eval()
Trans.eval()
vgg.eval()

from collections import OrderedDict

new_state_dict = OrderedDict()
state_dict = torch.load(args.decoder_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # remove the `module.` prefix for checkpoints saved with DataParallel
    namekey = k
    new_state_dict[namekey] = v
decoder.load_state_dict(new_state_dict)

new_state_dict = OrderedDict()
state_dict = torch.load(args.Trans_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # remove the `module.` prefix
    namekey = k
    new_state_dict[namekey] = v
Trans.load_state_dict(new_state_dict)

new_state_dict = OrderedDict()
state_dict = torch.load(args.embedding_path)
for k, v in state_dict.items():
    # namekey = k[7:]  # remove the `module.` prefix
    namekey = k
    new_state_dict[namekey] = v
embedding.load_state_dict(new_state_dict)
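
The commented-out `namekey = k[7:]` lines exist because checkpoints saved from an nn.DataParallel-wrapped model carry a `module.` prefix on every key. A small helper covering both cases (hypothetical, not part of the script):

    from collections import OrderedDict

    def strip_module_prefix(state_dict):
        """Drop a leading 'module.' from checkpoint keys if present (DataParallel artifact)."""
        cleaned = OrderedDict()
        for k, v in state_dict.items():
            cleaned[k[7:] if k.startswith('module.') else k] = v
        return cleaned

    # decoder.load_state_dict(strip_module_prefix(torch.load(args.decoder_path)))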

network = StyTR.StyTrans(vgg, decoder, embedding, Trans, args)
network.eval()
network.to(device)


content_tf = test_transform(content_size, crop)
style_tf = test_transform(style_size, crop)

for content_path in content_paths:
    for style_path in style_paths:
        print(content_path)

        content_tf1 = content_transform()
        content = content_tf(Image.open(content_path).convert("RGB"))

        h, w, c = np.shape(content)
        style_tf1 = style_transform(h, w)
        style = style_tf(Image.open(style_path).convert("RGB"))

        style = style.to(device).unsqueeze(0)
        content = content.to(device).unsqueeze(0)

        with torch.no_grad():
            output = network(content, style)
        output = output[0].cpu()

        output_name = '{:s}/{:s}_stylized_{:s}{:s}'.format(
            output_path, splitext(basename(content_path))[0],
            splitext(basename(style_path))[0], save_ext
        )

        save_image(output, output_name)
transformer.py ADDED
@@ -0,0 +1,322 @@
import copy
from typing import Optional, List

import torch
import torch.nn.functional as F
from torch import nn, Tensor
from function import normal, normal_style
import numpy as np
import os

device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"


class Transformer(nn.Module):

    def __init__(self, d_model=512, nhead=8, num_encoder_layers=3,
                 num_decoder_layers=3, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False,
                 return_intermediate_dec=False):
        super().__init__()

        encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
        self.encoder_c = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
        self.encoder_s = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)

        decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
                                                dropout, activation, normalize_before)
        decoder_norm = nn.LayerNorm(d_model)
        self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
                                          return_intermediate=return_intermediate_dec)

        self._reset_parameters()

        self.d_model = d_model
        self.nhead = nhead

        self.new_ps = nn.Conv2d(512, 512, (1, 1))
        self.averagepooling = nn.AdaptiveAvgPool2d(18)

    def _reset_parameters(self):
        for p in self.parameters():
            if p.dim() > 1:
                nn.init.xavier_uniform_(p)

    def forward(self, style, mask, content, pos_embed_c, pos_embed_s):

        # content-aware positional embedding
        content_pool = self.averagepooling(content)
        pos_c = self.new_ps(content_pool)
        pos_embed_c = F.interpolate(pos_c, mode='bilinear', size=style.shape[-2:])

        # flatten NxCxHxW to HWxNxC
        style = style.flatten(2).permute(2, 0, 1)
        if pos_embed_s is not None:
            pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)

        content = content.flatten(2).permute(2, 0, 1)
        if pos_embed_c is not None:
            pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)

        style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
        content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
        hs = self.decoder(content, style, memory_key_padding_mask=mask,
                          pos=pos_embed_s, query_pos=pos_embed_c)[0]

        # HWxNxC back to NxCxHxW
        N, B, C = hs.shape
        H = int(np.sqrt(N))
        hs = hs.permute(1, 2, 0)
        hs = hs.view(B, C, -1, H)

        return hs
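
A minimal shape sketch for the module above (editorial; the random 512-channel 32×32 feature maps stand in for the PatchEmbed output of 256×256 images, so this only checks tensor plumbing):

    import torch

    net = Transformer()                        # defaults: d_model=512, 3 encoder/decoder layers
    net.eval()
    content = torch.randn(1, 512, 32, 32)
    style = torch.randn(1, 512, 32, 32)
    with torch.no_grad():
        hs = net(style, None, content, None, None)   # positional embeddings are built inside forward
    print(hs.shape)                            # torch.Size([1, 512, 32, 32]), ready for the CNN decoder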


class TransformerEncoder(nn.Module):

    def __init__(self, encoder_layer, num_layers, norm=None):
        super().__init__()
        self.layers = _get_clones(encoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm

    def forward(self, src,
                mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        output = src

        for layer in self.layers:
            output = layer(output, src_mask=mask,
                           src_key_padding_mask=src_key_padding_mask, pos=pos)

        if self.norm is not None:
            output = self.norm(output)

        return output


class TransformerDecoder(nn.Module):

    def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
        super().__init__()
        self.layers = _get_clones(decoder_layer, num_layers)
        self.num_layers = num_layers
        self.norm = norm
        self.return_intermediate = return_intermediate

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        output = tgt

        intermediate = []

        for layer in self.layers:
            output = layer(output, memory, tgt_mask=tgt_mask,
                           memory_mask=memory_mask,
                           tgt_key_padding_mask=tgt_key_padding_mask,
                           memory_key_padding_mask=memory_key_padding_mask,
                           pos=pos, query_pos=query_pos)
            if self.return_intermediate:
                intermediate.append(self.norm(output))

        if self.norm is not None:
            output = self.norm(output)
            if self.return_intermediate:
                intermediate.pop()
                intermediate.append(output)

        if self.return_intermediate:
            return torch.stack(intermediate)

        return output.unsqueeze(0)


class TransformerEncoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self,
                     src,
                     src_mask: Optional[Tensor] = None,
                     src_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None):
        q = k = self.with_pos_embed(src, pos)
        # q = k = src
        src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src = self.norm1(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = src + self.dropout2(src2)
        src = self.norm2(src)
        return src

    def forward_pre(self, src,
                    src_mask: Optional[Tensor] = None,
                    src_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None):
        src2 = self.norm1(src)
        q = k = self.with_pos_embed(src2, pos)
        src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
                              key_padding_mask=src_key_padding_mask)[0]
        src = src + self.dropout1(src2)
        src2 = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
        src = src + self.dropout2(src2)
        return src

    def forward(self, src,
                src_mask: Optional[Tensor] = None,
                src_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
        return self.forward_post(src, src_mask, src_key_padding_mask, pos)


class TransformerDecoderLayer(nn.Module):

    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
                 activation="relu", normalize_before=False):
        super().__init__()
        # d_model: embedding dimension
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        self.multihead_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
        # Implementation of Feedforward model
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)
        self.normalize_before = normalize_before

    def with_pos_embed(self, tensor, pos: Optional[Tensor]):
        return tensor if pos is None else tensor + pos

    def forward_post(self, tgt, memory,
                     tgt_mask: Optional[Tensor] = None,
                     memory_mask: Optional[Tensor] = None,
                     tgt_key_padding_mask: Optional[Tensor] = None,
                     memory_key_padding_mask: Optional[Tensor] = None,
                     pos: Optional[Tensor] = None,
                     query_pos: Optional[Tensor] = None):

        q = self.with_pos_embed(tgt, query_pos)
        k = self.with_pos_embed(memory, pos)
        v = memory

        tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]
        tgt = tgt + self.dropout2(tgt2)
        tgt = self.norm2(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
        tgt = tgt + self.dropout3(tgt2)
        tgt = self.norm3(tgt)
        return tgt

    def forward_pre(self, tgt, memory,
                    tgt_mask: Optional[Tensor] = None,
                    memory_mask: Optional[Tensor] = None,
                    tgt_key_padding_mask: Optional[Tensor] = None,
                    memory_key_padding_mask: Optional[Tensor] = None,
                    pos: Optional[Tensor] = None,
                    query_pos: Optional[Tensor] = None):
        tgt2 = self.norm1(tgt)
        q = k = self.with_pos_embed(tgt2, query_pos)
        tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
                              key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt2 = self.norm2(tgt)
        tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
                                   key=self.with_pos_embed(memory, pos),
                                   value=memory, attn_mask=memory_mask,
                                   key_padding_mask=memory_key_padding_mask)[0]

        tgt = tgt + self.dropout2(tgt2)
        tgt2 = self.norm3(tgt)
        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
        tgt = tgt + self.dropout3(tgt2)
        return tgt

    def forward(self, tgt, memory,
                tgt_mask: Optional[Tensor] = None,
                memory_mask: Optional[Tensor] = None,
                tgt_key_padding_mask: Optional[Tensor] = None,
                memory_key_padding_mask: Optional[Tensor] = None,
                pos: Optional[Tensor] = None,
                query_pos: Optional[Tensor] = None):
        if self.normalize_before:
            return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
                                    tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
        return self.forward_post(tgt, memory, tgt_mask, memory_mask,
                                 tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)


def _get_clones(module, N):
    return nn.ModuleList([copy.deepcopy(module) for i in range(N)])


def build_transformer(args):
    return Transformer(
        d_model=args.hidden_dim,
        dropout=args.dropout,
        nhead=args.nheads,
        dim_feedforward=args.dim_feedforward,
        num_encoder_layers=args.enc_layers,
        num_decoder_layers=args.dec_layers,
        normalize_before=args.pre_norm,
        return_intermediate_dec=True,
    )


def _get_activation_fn(activation):
    """Return an activation function given a string"""
    if activation == "relu":
        return F.relu
    if activation == "gelu":
        return F.gelu
    if activation == "glu":
        return F.glu
    raise RuntimeError(f"activation should be relu/gelu/glu, not {activation}.")
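
Note that build_transformer expects more fields on args than test.py declares (test.py only defines --hidden_dim and --position_embedding, which is why it instantiates Transformer() directly). A minimal sketch of the namespace build_transformer needs, with assumed values matching the class defaults:

    from argparse import Namespace

    args = Namespace(hidden_dim=512, dropout=0.1, nheads=8, dim_feedforward=2048,
                     enc_layers=3, dec_layers=3, pre_norm=False)
    model = build_transformer(args)     # returns intermediate decoder outputs (return_intermediate_dec=True)
    print(model.d_model, model.nhead)   # 512 8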