Sophie98 committed on
Commit ad1ac8f
1 Parent(s): 6048967

change to streamlit

.flake8 DELETED
@@ -1,12 +0,0 @@
- [flake8]
- exclude =
-     .git,
-     *.egg-info,
-     __pycache__,
-     .tox,
-     .pytest_cache,
-     build,
-     dist,
-     tests
- max-line-length = 88
- ignore = D202,W503,E203 # conflicts with black
 
.streamlit/config.toml ADDED
@@ -0,0 +1,3 @@
+ [theme]
+ base="dark"
+ primaryColor="#04b188"
README.md CHANGED
@@ -3,8 +3,8 @@ title: SofaStyler
  emoji: 🛋
  colorFrom: blue
  colorTo: green
- sdk: gradio
- sdk_version: 2.9.4
  app_file: app.py
  pinned: false
  ---

  emoji: 🛋
  colorFrom: blue
  colorTo: green
+ sdk: streamlit
+ sdk_version: 1.9.0
  app_file: app.py
  pinned: false
  ---
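The substantive change here is the SDK switch from Gradio 2.9.4 to Streamlit 1.9.0, so the app.py this metadata points at must now be a Streamlit script rather than a Gradio Interface. The commit does not show app.py itself; the snippet below is only a minimal sketch of the kind of entry point this configuration expects, assuming it reuses get_mask and replace_sofa from Segmentation/segmentation.py (widget labels, file names and the placeholder style-transfer step are illustrative assumptions, not the Space's real code).

import streamlit as st
from PIL import Image

from Segmentation.segmentation import get_mask, replace_sofa

st.title("SofaStyler")

content_file = st.file_uploader("Photo of your sofa", type=["png", "jpg", "jpeg"])
style_file = st.file_uploader("Style image", type=["png", "jpg", "jpeg"])

if content_file is not None and style_file is not None:
    content = Image.open(content_file).convert("RGB").resize((640, 640))
    style = Image.open(style_file).convert("RGB").resize((640, 640))

    mask = get_mask(content)        # sofa segmentation mask (L mode)
    styled_sofa = style             # placeholder: the real app runs style transfer here
    result = replace_sofa(content, mask, styled_sofa)

    st.image(result, caption="Styled sofa")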
Segmentation/{model_checkpoint.h5 β†’ model_final.h5} RENAMED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:ffd37e18fd15d753a6b7a9f8c589712508894fca6cbcfd2002d5053743788b70
- size 130752128

  version https://git-lfs.github.com/spec/v1
+ oid sha256:9a456f38c83897d9d8b5c8dd989ff7ee2fe13bb123a70a00b6e987d4efac1c6e
+ size 130858696
Segmentation/segmentation.py CHANGED
@@ -1,60 +1,45 @@
  import cv2
  from tensorflow import keras
  import numpy as np
  from PIL import Image
  import segmentation_models as sm
- sm.set_framework('tf.keras')
-
- # Load model at build time
- model_path = "Segmentation/model_checkpoint.h5"
- CLASSES = ['sofa']
- BACKBONE = 'resnet50'
-
- # define network parameters (only neede to load the weights)
- n_classes = 1 if len(CLASSES) == 1 else (len(CLASSES) + 1)
- activation = 'sigmoid' if n_classes == 1 else 'softmax'
  preprocess_input = sm.get_preprocessing(BACKBONE)
- LR=0.0001
- #create model architecture
- model = sm.Unet(BACKBONE, classes=n_classes, activation=activation)
- # define optomizer
- optim = keras.optimizers.Adam(LR)
- dice_loss = sm.losses.DiceLoss()
- focal_loss = sm.losses.BinaryFocalLoss() if n_classes == 1 else sm.losses.CategoricalFocalLoss()
- total_loss = dice_loss + (1 * focal_loss)
- metrics = [sm.metrics.IOUScore(threshold=0.5), sm.metrics.FScore(threshold=0.5)]
- # compile keras model with defined optimozer, loss and metrics
- model.compile(optim, total_loss, metrics)
- model.load_weights(model_path)
-
- def get_mask(image:Image.Image) -> Image.Image:
      """
-     This function generates a mask of the image that highlights all the sofas in the image.
-     This uses a pre-trained Unet model with a resnet50 backbone.
-     Remark: The model was trained on 640by640 images and it is therefore best that the image has the same size.
      Parameters:
      image = original image
      Return:
      mask = corresponding maks of the image
      """
-
-     # #load model
-
-     #model = keras.models.load_model('model_final.h5', compile=False)
-     print('loaded model')
-     test_img = np.array(image)
      test_img = cv2.resize(test_img, (640, 640))
      test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
      test_img = np.expand_dims(test_img, axis=0)

      prediction = model.predict(preprocess_input(np.array(test_img))).round()
-     mask = Image.fromarray(prediction[...,0].squeeze()*255).convert("L")
-     return mask
-
- def replace_sofa(image:Image.Image, mask:Image.Image, styled_sofa:Image.Image) -> Image.Image:
      """
-     This function replaces the original sofa in the image by the new styled sofa according
-     to the mask.
      Remark: All images should have the same size.
      Input:
      image = Original image
@@ -63,11 +48,11 @@ def replace_sofa(image:Image.Image, mask:Image.Image, styled_sofa:Image.Image) -> Image.Image:
      Return:
      new_image = Image containing the styled sofa
      """
-     image,mask,styled_sofa = np.array(image),np.array(mask),np.array(styled_sofa)

      _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
      mask_inv = cv2.bitwise_not(mask)
-     image_bg = cv2.bitwise_and(image,image,mask = mask_inv)
-     sofa_fg = cv2.bitwise_and(styled_sofa,styled_sofa,mask = mask)
-     new_image = cv2.add(image_bg,sofa_fg)
      return Image.fromarray(new_image)
 
+ # Import libraries
+
  import cv2
  from tensorflow import keras
  import numpy as np
  from PIL import Image
  import segmentation_models as sm

+ sm.set_framework("tf.keras")

+ # Load segmentation model
+ BACKBONE = "resnet50"
  preprocess_input = sm.get_preprocessing(BACKBONE)
+ model = keras.models.load_model("Segmentation/model_final.h5", compile=False)
+

+ def get_mask(image: Image) -> Image:
      """
+     This function generates a mask of the image that highlights all the sofas
+     in the image. This uses a pre-trained Unet model with a resnet50 backbone.
+     Remark: The model was trained on 640by640 images and it is therefore best
+     that the image has the same size.
+
      Parameters:
      image = original image
      Return:
      mask = corresponding maks of the image
      """
+     test_img = np.array(image)
      test_img = cv2.resize(test_img, (640, 640))
      test_img = cv2.cvtColor(test_img, cv2.COLOR_RGB2BGR)
      test_img = np.expand_dims(test_img, axis=0)

      prediction = model.predict(preprocess_input(np.array(test_img))).round()
+     mask = Image.fromarray(prediction[..., 0].squeeze() * 255).convert("L")
+     return mask
+

+ def replace_sofa(image: Image, mask: Image, styled_sofa: Image) -> Image:
      """
+     This function replaces the original sofa in the image by the new styled
+     sofa according to the mask.
      Remark: All images should have the same size.
      Input:
      image = Original image

      Return:
      new_image = Image containing the styled sofa
      """
+     image, mask, styled_sofa = np.array(image), np.array(mask), np.array(styled_sofa)

      _, mask = cv2.threshold(mask, 10, 255, cv2.THRESH_BINARY)
      mask_inv = cv2.bitwise_not(mask)
+     image_bg = cv2.bitwise_and(image, image, mask=mask_inv)
+     sofa_fg = cv2.bitwise_and(styled_sofa, styled_sofa, mask=mask)
+     new_image = cv2.add(image_bg, sofa_fg)
      return Image.fromarray(new_image)
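With this change the model is loaded once at import time via keras.models.load_model, so downstream code only needs the two helpers above. A minimal end-to-end sketch (the example file paths are illustrative, not part of the repository):

from PIL import Image

from Segmentation.segmentation import get_mask, replace_sofa

# Hypothetical inputs; both are resized to the 640x640 resolution the model expects.
original = Image.open("examples/sofa.jpg").convert("RGB").resize((640, 640))
styled = Image.open("examples/sofa_styled.jpg").convert("RGB").resize((640, 640))

mask = get_mask(original)                      # L-mode mask, white where the sofa was detected
result = replace_sofa(original, mask, styled)  # styled sofa composited back into the scene
result.save("examples/sofa_result.jpg")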
StyleTransfer/{StyTR.py β†’ srcTransformer/StyTR.py} RENAMED
@@ -1,17 +1,24 @@
1
  import torch
2
  import torch.nn.functional as F
 
 
 
 
 
 
3
  from torch import nn
4
- from StyleTransfer.misc import (NestedTensor, nested_tensor_from_tensor_list,
5
- accuracy, get_world_size, interpolate,
6
- is_dist_avail_and_initialized)
7
- from StyleTransfer.function import normal,normal_style
8
- from StyleTransfer.function import calc_mean_std
9
- from StyleTransfer.ViT_helper import DropPath, to_2tuple, trunc_normal_
10
 
11
  class PatchEmbed(nn.Module):
12
- """ Image to Patch Embedding
13
- """
14
- def __init__(self, img_size=256, patch_size=8, in_chans=3, embed_dim=512):
 
 
 
 
 
 
15
  super().__init__()
16
  img_size = to_2tuple(img_size)
17
  patch_size = to_2tuple(patch_size)
@@ -19,9 +26,11 @@ class PatchEmbed(nn.Module):
19
  self.img_size = img_size
20
  self.patch_size = patch_size
21
  self.num_patches = num_patches
22
-
23
- self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
24
- self.up1 = nn.Upsample(scale_factor=2, mode='nearest')
 
 
25
 
26
  def forward(self, x):
27
  B, C, H, W = x.shape
@@ -34,7 +43,7 @@ decoder = nn.Sequential(
34
  nn.ReflectionPad2d((1, 1, 1, 1)),
35
  nn.Conv2d(512, 256, (3, 3)),
36
  nn.ReLU(),
37
- nn.Upsample(scale_factor=2, mode='nearest'),
38
  nn.ReflectionPad2d((1, 1, 1, 1)),
39
  nn.Conv2d(256, 256, (3, 3)),
40
  nn.ReLU(),
@@ -47,14 +56,14 @@ decoder = nn.Sequential(
47
  nn.ReflectionPad2d((1, 1, 1, 1)),
48
  nn.Conv2d(256, 128, (3, 3)),
49
  nn.ReLU(),
50
- nn.Upsample(scale_factor=2, mode='nearest'),
51
  nn.ReflectionPad2d((1, 1, 1, 1)),
52
  nn.Conv2d(128, 128, (3, 3)),
53
  nn.ReLU(),
54
  nn.ReflectionPad2d((1, 1, 1, 1)),
55
  nn.Conv2d(128, 64, (3, 3)),
56
  nn.ReLU(),
57
- nn.Upsample(scale_factor=2, mode='nearest'),
58
  nn.ReflectionPad2d((1, 1, 1, 1)),
59
  nn.Conv2d(64, 64, (3, 3)),
60
  nn.ReLU(),
@@ -115,26 +124,35 @@ vgg = nn.Sequential(
115
  nn.ReLU(), # relu5-3
116
  nn.ReflectionPad2d((1, 1, 1, 1)),
117
  nn.Conv2d(512, 512, (3, 3)),
118
- nn.ReLU() # relu5-4
119
  )
120
 
 
121
  class MLP(nn.Module):
122
- """ Very simple multi-layer perceptron (also called FFN)"""
123
 
124
- def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
 
 
125
  super().__init__()
126
  self.num_layers = num_layers
127
  h = [hidden_dim] * (num_layers - 1)
128
- self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
 
 
129
 
130
  def forward(self, x):
131
  for i, layer in enumerate(self.layers):
132
  x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
133
  return x
 
 
134
  class StyTrans(nn.Module):
135
- """ This is the style transform transformer module """
136
-
137
- def __init__(self,encoder,decoder,PatchEmbed, transformer):
 
 
138
 
139
  super().__init__()
140
  enc_layers = list(encoder.children())
@@ -143,85 +161,96 @@ class StyTrans(nn.Module):
143
  self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
144
  self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
145
  self.enc_5 = nn.Sequential(*enc_layers[31:44]) # relu4_1 -> relu5_1
146
-
147
- for name in ['enc_1', 'enc_2', 'enc_3', 'enc_4', 'enc_5']:
148
  for param in getattr(self, name).parameters():
149
  param.requires_grad = False
150
 
151
  self.mse_loss = nn.MSELoss()
152
  self.transformer = transformer
153
- hidden_dim = transformer.d_model
154
  self.decode = decoder
155
  self.embedding = PatchEmbed
156
 
157
  def encode_with_intermediate(self, input):
158
  results = [input]
159
  for i in range(5):
160
- func = getattr(self, 'enc_{:d}'.format(i + 1))
161
  results.append(func(results[-1]))
162
  return results[1:]
163
 
164
  def calc_content_loss(self, input, target):
165
- assert (input.size() == target.size())
166
- assert (target.requires_grad is False)
167
- return self.mse_loss(input, target)
168
 
169
  def calc_style_loss(self, input, target):
170
- assert (input.size() == target.size())
171
- assert (target.requires_grad is False)
172
  input_mean, input_std = calc_mean_std(input)
173
  target_mean, target_std = calc_mean_std(target)
174
- return self.mse_loss(input_mean, target_mean) + \
175
- self.mse_loss(input_std, target_std)
176
- def forward(self, samples_c: NestedTensor,samples_s: NestedTensor):
177
- """Β The forward expects a NestedTensor, which consists of:
178
- - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
179
- - samples.mask: a binary mask of shape [batch_size x H x W], containing 1 on padded pixels
 
 
 
180
 
181
  """
182
  content_input = samples_c
183
  style_input = samples_s
184
  if isinstance(samples_c, (list, torch.Tensor)):
185
- samples_c = nested_tensor_from_tensor_list(samples_c) # support different-sized images padding is used for mask [tensor, mask]
 
 
186
  if isinstance(samples_s, (list, torch.Tensor)):
187
- samples_s = nested_tensor_from_tensor_list(samples_s)
188
-
189
- # ### features used to calcate loss
190
  content_feats = self.encode_with_intermediate(samples_c.tensors)
191
  style_feats = self.encode_with_intermediate(samples_s.tensors)
192
 
193
- ### Linear projection
194
  style = self.embedding(samples_s.tensors)
195
  content = self.embedding(samples_c.tensors)
196
-
197
  # postional embedding is calculated in transformer.py
198
  pos_s = None
199
  pos_c = None
200
 
201
  mask = None
202
- hs = self.transformer(style, mask , content, pos_c, pos_s)
203
  Ics = self.decode(hs)
204
 
205
  Ics_feats = self.encode_with_intermediate(Ics)
206
- loss_c = self.calc_content_loss(normal(Ics_feats[-1]), normal(content_feats[-1]))+self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
 
 
207
  # Style loss
208
  loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
209
  for i in range(1, 5):
210
  loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])
211
-
212
-
213
- Icc = self.decode(self.transformer(content, mask , content, pos_c, pos_c))
214
- Iss = self.decode(self.transformer(style, mask , style, pos_s, pos_s))
215
-
216
- #Identity losses lambda 1
217
- loss_lambda1 = self.calc_content_loss(Icc,content_input)+self.calc_content_loss(Iss,style_input)
218
-
219
- #Identity losses lambda 2
220
- Icc_feats=self.encode_with_intermediate(Icc)
221
- Iss_feats=self.encode_with_intermediate(Iss)
222
- loss_lambda2 = self.calc_content_loss(Icc_feats[0], content_feats[0])+self.calc_content_loss(Iss_feats[0], style_feats[0])
 
 
 
223
  for i in range(1, 5):
224
- loss_lambda2 += self.calc_content_loss(Icc_feats[i], content_feats[i])+self.calc_content_loss(Iss_feats[i], style_feats[i])
 
 
225
  # Please select and comment out one of the following two sentences
226
- return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2 #train
227
- # return Ics #test
 
1
  import torch
2
  import torch.nn.functional as F
3
+ from StyleTransfer.srcTransformer.function import calc_mean_std, normal
4
+ from StyleTransfer.srcTransformer.misc import (
5
+ NestedTensor,
6
+ nested_tensor_from_tensor_list,
7
+ )
8
+ from StyleTransfer.srcTransformer.ViT_helper import to_2tuple
9
  from torch import nn
10
+
 
 
 
 
 
11
 
12
  class PatchEmbed(nn.Module):
13
+ """Image to Patch Embedding"""
14
+
15
+ def __init__(
16
+ self,
17
+ img_size: int = 256,
18
+ patch_size: int = 8,
19
+ in_chans: int = 3,
20
+ embed_dim: int = 512,
21
+ ):
22
  super().__init__()
23
  img_size = to_2tuple(img_size)
24
  patch_size = to_2tuple(patch_size)
 
26
  self.img_size = img_size
27
  self.patch_size = patch_size
28
  self.num_patches = num_patches
29
+
30
+ self.proj = nn.Conv2d(
31
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
32
+ )
33
+ self.up1 = nn.Upsample(scale_factor=2, mode="nearest")
34
 
35
  def forward(self, x):
36
  B, C, H, W = x.shape
 
43
  nn.ReflectionPad2d((1, 1, 1, 1)),
44
  nn.Conv2d(512, 256, (3, 3)),
45
  nn.ReLU(),
46
+ nn.Upsample(scale_factor=2, mode="nearest"),
47
  nn.ReflectionPad2d((1, 1, 1, 1)),
48
  nn.Conv2d(256, 256, (3, 3)),
49
  nn.ReLU(),
 
56
  nn.ReflectionPad2d((1, 1, 1, 1)),
57
  nn.Conv2d(256, 128, (3, 3)),
58
  nn.ReLU(),
59
+ nn.Upsample(scale_factor=2, mode="nearest"),
60
  nn.ReflectionPad2d((1, 1, 1, 1)),
61
  nn.Conv2d(128, 128, (3, 3)),
62
  nn.ReLU(),
63
  nn.ReflectionPad2d((1, 1, 1, 1)),
64
  nn.Conv2d(128, 64, (3, 3)),
65
  nn.ReLU(),
66
+ nn.Upsample(scale_factor=2, mode="nearest"),
67
  nn.ReflectionPad2d((1, 1, 1, 1)),
68
  nn.Conv2d(64, 64, (3, 3)),
69
  nn.ReLU(),
 
124
  nn.ReLU(), # relu5-3
125
  nn.ReflectionPad2d((1, 1, 1, 1)),
126
  nn.Conv2d(512, 512, (3, 3)),
127
+ nn.ReLU(), # relu5-4
128
  )
129
 
130
+
131
  class MLP(nn.Module):
132
+ """Very simple multi-layer perceptron (also called FFN)"""
133
 
134
+ def __init__(
135
+ self, input_dim: int, hidden_dim: int, output_dim: int, num_layers: int
136
+ ):
137
  super().__init__()
138
  self.num_layers = num_layers
139
  h = [hidden_dim] * (num_layers - 1)
140
+ self.layers = nn.ModuleList(
141
+ nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim])
142
+ )
143
 
144
  def forward(self, x):
145
  for i, layer in enumerate(self.layers):
146
  x = F.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
147
  return x
148
+
149
+
150
  class StyTrans(nn.Module):
151
+ """This is the style transform transformer module"""
152
+
153
+ def __init__(
154
+ self, encoder: nn.Sequential, decoder: nn.Sequential, PatchEmbed, transformer
155
+ ):
156
 
157
  super().__init__()
158
  enc_layers = list(encoder.children())
 
161
  self.enc_3 = nn.Sequential(*enc_layers[11:18]) # relu2_1 -> relu3_1
162
  self.enc_4 = nn.Sequential(*enc_layers[18:31]) # relu3_1 -> relu4_1
163
  self.enc_5 = nn.Sequential(*enc_layers[31:44]) # relu4_1 -> relu5_1
164
+
165
+ for name in ["enc_1", "enc_2", "enc_3", "enc_4", "enc_5"]:
166
  for param in getattr(self, name).parameters():
167
  param.requires_grad = False
168
 
169
  self.mse_loss = nn.MSELoss()
170
  self.transformer = transformer
 
171
  self.decode = decoder
172
  self.embedding = PatchEmbed
173
 
174
  def encode_with_intermediate(self, input):
175
  results = [input]
176
  for i in range(5):
177
+ func = getattr(self, "enc_{:d}".format(i + 1))
178
  results.append(func(results[-1]))
179
  return results[1:]
180
 
181
  def calc_content_loss(self, input, target):
182
+ assert input.size() == target.size()
183
+ assert target.requires_grad is False
184
+ return self.mse_loss(input, target)
185
 
186
  def calc_style_loss(self, input, target):
187
+ assert input.size() == target.size()
188
+ assert target.requires_grad is False
189
  input_mean, input_std = calc_mean_std(input)
190
  target_mean, target_std = calc_mean_std(target)
191
+ return self.mse_loss(input_mean, target_mean) + self.mse_loss(
192
+ input_std, target_std
193
+ )
194
+
195
+ def forward(self, samples_c: NestedTensor, samples_s: NestedTensor):
196
+ """The forward expects a NestedTensor, which consists of:
197
+ - samples.tensor: batched images, of shape [batch_size x 3 x H x W]
198
+ - samples.mask: a binary mask of shape [batch_size x H x W],
199
+ containing 1 on padded pixels
200
 
201
  """
202
  content_input = samples_c
203
  style_input = samples_s
204
  if isinstance(samples_c, (list, torch.Tensor)):
205
+ samples_c = nested_tensor_from_tensor_list(
206
+ samples_c
207
+ ) # support different-sized images padding is used for mask [tensor, mask]
208
  if isinstance(samples_s, (list, torch.Tensor)):
209
+ samples_s = nested_tensor_from_tensor_list(samples_s)
210
+
211
+ # features used to calcate loss
212
  content_feats = self.encode_with_intermediate(samples_c.tensors)
213
  style_feats = self.encode_with_intermediate(samples_s.tensors)
214
 
215
+ # Linear projection
216
  style = self.embedding(samples_s.tensors)
217
  content = self.embedding(samples_c.tensors)
218
+
219
  # postional embedding is calculated in transformer.py
220
  pos_s = None
221
  pos_c = None
222
 
223
  mask = None
224
+ hs = self.transformer(style, mask, content, pos_c, pos_s)
225
  Ics = self.decode(hs)
226
 
227
  Ics_feats = self.encode_with_intermediate(Ics)
228
+ loss_c = self.calc_content_loss(
229
+ normal(Ics_feats[-1]), normal(content_feats[-1])
230
+ ) + self.calc_content_loss(normal(Ics_feats[-2]), normal(content_feats[-2]))
231
  # Style loss
232
  loss_s = self.calc_style_loss(Ics_feats[0], style_feats[0])
233
  for i in range(1, 5):
234
  loss_s += self.calc_style_loss(Ics_feats[i], style_feats[i])
235
+
236
+ Icc = self.decode(self.transformer(content, mask, content, pos_c, pos_c))
237
+ Iss = self.decode(self.transformer(style, mask, style, pos_s, pos_s))
238
+
239
+ # Identity losses lambda 1
240
+ loss_lambda1 = self.calc_content_loss(
241
+ Icc, content_input
242
+ ) + self.calc_content_loss(Iss, style_input)
243
+
244
+ # Identity losses lambda 2
245
+ Icc_feats = self.encode_with_intermediate(Icc)
246
+ Iss_feats = self.encode_with_intermediate(Iss)
247
+ loss_lambda2 = self.calc_content_loss(
248
+ Icc_feats[0], content_feats[0]
249
+ ) + self.calc_content_loss(Iss_feats[0], style_feats[0])
250
  for i in range(1, 5):
251
+ loss_lambda2 += self.calc_content_loss(
252
+ Icc_feats[i], content_feats[i]
253
+ ) + self.calc_content_loss(Iss_feats[i], style_feats[i])
254
  # Please select and comment out one of the following two sentences
255
+ return Ics, loss_c, loss_s, loss_lambda1, loss_lambda2 # train
256
+ # return Ics #test
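StyTrans itself only wires pre-built parts together: the VGG encoder (truncated to relu5_1 internally), the convolutional decoder, a PatchEmbed instance and the Transformer from transformer.py. Below is a hedged inference sketch using the checkpoints this commit moves under srcTransformer/Transformer_models; the loading details (and any state-dict key cleanup) follow the upstream StyTR-2 project and are assumptions, not something shown in this diff.

import torch
from StyleTransfer.srcTransformer import StyTR, transformer

CKPT = "StyleTransfer/srcTransformer/Transformer_models"

vgg = StyTR.vgg
vgg.load_state_dict(torch.load(f"{CKPT}/vgg_normalised.pth"))

decoder = StyTR.decoder
embedding = StyTR.PatchEmbed()
Trans = transformer.Transformer()

# The released checkpoints may need key cleanup (e.g. stripping "module." prefixes).
decoder.load_state_dict(torch.load(f"{CKPT}/decoder_iter_160000.pth"))
embedding.load_state_dict(torch.load(f"{CKPT}/embedding_iter_160000.pth"))
Trans.load_state_dict(torch.load(f"{CKPT}/transformer_iter_160000.pth"))

network = StyTR.StyTrans(vgg, decoder, embedding, Trans).eval()

with torch.no_grad():
    content = torch.rand(1, 3, 256, 256)    # placeholder content image tensor
    style = torch.rand(1, 3, 256, 256)      # placeholder style image tensor
    Ics, *losses = network(content, style)  # forward currently returns Ics plus the training losses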
StyleTransfer/{models β†’ srcTransformer/Transformer_models}/decoder_iter_160000.pth RENAMED
File without changes
StyleTransfer/{models β†’ srcTransformer/Transformer_models}/embedding_iter_160000.pth RENAMED
File without changes
StyleTransfer/{models β†’ srcTransformer/Transformer_models}/transformer_iter_160000.pth RENAMED
File without changes
StyleTransfer/{models β†’ srcTransformer/Transformer_models}/vgg_normalised.pth RENAMED
File without changes
StyleTransfer/{ViT_helper.py β†’ srcTransformer/ViT_helper.py} RENAMED
@@ -1,18 +1,30 @@
 
 
 
 
1
  import torch
2
  from torch import nn
 
3
 
4
- def drop_path(x, drop_prob: float = 0., training: bool = False):
5
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
6
- This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
7
- the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
8
- See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
9
- changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
10
- 'survival rate' as the argument.
 
 
 
 
 
11
  """
12
- if drop_prob == 0. or not training:
13
  return x
14
  keep_prob = 1 - drop_prob
15
- shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
 
 
16
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
17
  random_tensor.floor_() # binarize
18
  output = x.div(keep_prob) * random_tensor
@@ -20,25 +32,26 @@ def drop_path(x, drop_prob: float = 0., training: bool = False):
20
 
21
 
22
  class DropPath(nn.Module):
23
- """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
24
  """
25
- def __init__(self, drop_prob=None):
 
 
 
 
26
  super(DropPath, self).__init__()
27
  self.drop_prob = drop_prob
28
 
29
  def forward(self, x):
30
  return drop_path(x, self.drop_prob, self.training)
31
 
32
- from itertools import repeat
33
- from torch._six import container_abcs
34
-
35
 
36
  # From PyTorch internals
37
- def _ntuple(n):
38
  def parse(x):
39
  if isinstance(x, container_abcs.Iterable):
40
  return x
41
  return tuple(repeat(x, n))
 
42
  return parse
43
 
44
 
@@ -48,41 +61,41 @@ to_3tuple = _ntuple(3)
48
  to_4tuple = _ntuple(4)
49
 
50
 
51
-
52
- import torch
53
- import math
54
- import warnings
55
-
56
-
57
- def _no_grad_trunc_normal_(tensor, mean, std, a, b):
58
- # Cut & paste from PyTorch official master until it's in a few official releases - RW
59
- # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
60
  def norm_cdf(x):
61
  # Computes standard normal cumulative distribution function
62
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
63
 
64
  if (mean < a - 2 * std) or (mean > b + 2 * std):
65
- warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
66
- "The distribution of values may be incorrect.",
67
- stacklevel=2)
 
 
68
 
69
  with torch.no_grad():
70
  # Values are generated by using a truncated uniform distribution and
71
  # then using the inverse CDF for the normal distribution.
72
  # Get upper and lower cdf values
73
- l = norm_cdf((a - mean) / std)
74
- u = norm_cdf((b - mean) / std)
75
 
76
  # Uniformly fill tensor with values from [l, u], then translate to
77
  # [2l-1, 2u-1].
78
- tensor.uniform_(2 * l - 1, 2 * u - 1)
79
 
80
  # Use inverse cdf transform for normal distribution to get truncated
81
  # standard normal
82
  tensor.erfinv_()
83
 
84
  # Transform to proper mean, std
85
- tensor.mul_(std * math.sqrt(2.))
86
  tensor.add_(mean)
87
 
88
  # Clamp to ensure it's in the proper range
@@ -90,8 +103,13 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
90
  return tensor
91
 
92
 
93
- def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
94
- # type: (Tensor, float, float, float, float) -> Tensor
 
 
 
 
 
95
  r"""Fills the input Tensor with values drawn from a truncated
96
  normal distribution. The values are effectively drawn from the
97
  normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
@@ -108,4 +126,4 @@ def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
108
  >>> w = torch.empty(3, 5)
109
  >>> nn.init.trunc_normal_(w)
110
  """
111
- return _no_grad_trunc_normal_(tensor, mean, std, a, b)
 
1
+ import math
2
+ import warnings
3
+ from itertools import repeat
4
+
5
  import torch
6
  from torch import nn
7
+ from torch._six import container_abcs
8
 
9
+
10
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
11
+ """
12
+ Drop paths (Stochastic Depth) per sample (when applied in main
13
+ path of residual blocks). This is the same as the DropConnect impl
14
+ I created for EfficientNet, etc networks, however, the original name
15
+ is misleading as 'Drop Connect' is a different form of dropout in a
16
+ separate paper... See discussion:
17
+ https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ...
18
+ I've opted for changing the layer and argument names to 'drop path'
19
+ rather than mix DropConnect as a layer name and use 'survival rate'
20
+ as the argument.
21
  """
22
+ if drop_prob == 0.0 or not training:
23
  return x
24
  keep_prob = 1 - drop_prob
25
+ shape = (x.shape[0],) + (1,) * (
26
+ x.ndim - 1
27
+ ) # work with diff dim tensors, not just 2D ConvNets
28
  random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
29
  random_tensor.floor_() # binarize
30
  output = x.div(keep_prob) * random_tensor
 
32
 
33
 
34
  class DropPath(nn.Module):
 
35
  """
36
+ Drop paths (Stochastic Depth) per sample
37
+ (when applied in main path of residual blocks).
38
+ """
39
+
40
+ def __init__(self, drop_prob: float = None):
41
  super(DropPath, self).__init__()
42
  self.drop_prob = drop_prob
43
 
44
  def forward(self, x):
45
  return drop_path(x, self.drop_prob, self.training)
46
 
 
 
 
47
 
48
  # From PyTorch internals
49
+ def _ntuple(n: int):
50
  def parse(x):
51
  if isinstance(x, container_abcs.Iterable):
52
  return x
53
  return tuple(repeat(x, n))
54
+
55
  return parse
56
 
57
 
 
61
  to_4tuple = _ntuple(4)
62
 
63
 
64
+ def _no_grad_trunc_normal_(
65
+ tensor: torch.tensor, mean: float, std: float, a: float, b: float
66
+ ):
67
+ # Cut & paste from PyTorch official master
68
+ # until it's in a few official releases - RW
69
+ # Method based on:
70
+ # https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
 
 
71
  def norm_cdf(x):
72
  # Computes standard normal cumulative distribution function
73
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
74
 
75
  if (mean < a - 2 * std) or (mean > b + 2 * std):
76
+ warnings.warn(
77
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
78
+ "The distribution of values may be incorrect.",
79
+ stacklevel=2,
80
+ )
81
 
82
  with torch.no_grad():
83
  # Values are generated by using a truncated uniform distribution and
84
  # then using the inverse CDF for the normal distribution.
85
  # Get upper and lower cdf values
86
+ lower = norm_cdf((a - mean) / std)
87
+ upper = norm_cdf((b - mean) / std)
88
 
89
  # Uniformly fill tensor with values from [l, u], then translate to
90
  # [2l-1, 2u-1].
91
+ tensor.uniform_(2 * lower - 1, 2 * upper - 1)
92
 
93
  # Use inverse cdf transform for normal distribution to get truncated
94
  # standard normal
95
  tensor.erfinv_()
96
 
97
  # Transform to proper mean, std
98
+ tensor.mul_(std * math.sqrt(2.0))
99
  tensor.add_(mean)
100
 
101
  # Clamp to ensure it's in the proper range
 
103
  return tensor
104
 
105
 
106
+ def trunc_normal_(
107
+ tensor: torch.tensor,
108
+ mean: float = 0.0,
109
+ std: float = 1.0,
110
+ a: float = -2.0,
111
+ b: float = 2.0,
112
+ ):
113
  r"""Fills the input Tensor with values drawn from a truncated
114
  normal distribution. The values are effectively drawn from the
115
  normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
 
126
  >>> w = torch.empty(3, 5)
127
  >>> nn.init.trunc_normal_(w)
128
  """
129
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
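One portability note on the reordered imports above: from torch._six import container_abcs exists only in older PyTorch releases (it was removed around 1.9), so _ntuple will fail to import on newer builds. A small guarded import keeps the helper working; this is a suggested workaround, not part of the commit.

# Suggested fallback for newer PyTorch versions -- not in this commit.
try:
    from torch._six import container_abcs  # PyTorch < 1.9
except ImportError:
    import collections.abc as container_abcs  # torch._six.container_abcs was removed in later releases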
StyleTransfer/srcTransformer/__init__.py ADDED
File without changes
StyleTransfer/{function.py β†’ srcTransformer/function.py} RENAMED
@@ -4,35 +4,41 @@ import torch
  def calc_mean_std(feat, eps=1e-5):
      # eps is a small value added to the variance to avoid divide-by-zero.
      size = feat.size()
-     assert (len(size) == 4)
      N, C = size[:2]
      feat_var = feat.view(N, C, -1).var(dim=2) + eps
      feat_std = feat_var.sqrt().view(N, C, 1, 1)
      feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
      return feat_mean, feat_std

  def calc_mean_std1(feat, eps=1e-5):
      # eps is a small value added to the variance to avoid divide-by-zero.
      size = feat.size()
      # assert (len(size) == 4)
-     WH,N, C = size
      feat_var = feat.var(dim=0) + eps
      feat_std = feat_var.sqrt()
      feat_mean = feat.mean(dim=0)
      return feat_mean, feat_std
  def normal(feat, eps=1e-5):
-     feat_mean, feat_std= calc_mean_std(feat, eps)
-     normalized=(feat-feat_mean)/feat_std
-     return normalized
  def normal_style(feat, eps=1e-5):
-     feat_mean, feat_std= calc_mean_std1(feat, eps)
-     normalized=(feat-feat_mean)/feat_std
      return normalized

  def _calc_feat_flatten_mean_std(feat):
      # takes 3D feat (C, H, W), return mean and std of array within channels
-     assert (feat.size()[0] == 3)
-     assert (isinstance(feat, torch.FloatTensor))
      feat_flatten = feat.view(3, -1)
      mean = feat_flatten.mean(dim=-1, keepdim=True)
      std = feat_flatten.std(dim=-1, keepdim=True)
@@ -49,25 +55,24 @@ def coral(source, target):
      # Note: flatten -> f

      source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
-     source_f_norm = (source_f - source_f_mean.expand_as(
-         source_f)) / source_f_std.expand_as(source_f)
-     source_f_cov_eye = \
-         torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

      target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
-     target_f_norm = (target_f - target_f_mean.expand_as(
-         target_f)) / target_f_std.expand_as(target_f)
-     target_f_cov_eye = \
-         torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

      source_f_norm_transfer = torch.mm(
          _mat_sqrt(target_f_cov_eye),
-         torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)),
-                  source_f_norm)
      )

-     source_f_transfer = source_f_norm_transfer * \
-         target_f_std.expand_as(source_f_norm) + \
-         target_f_mean.expand_as(source_f_norm)

      return source_f_transfer.view(source.size())
 
  def calc_mean_std(feat, eps=1e-5):
      # eps is a small value added to the variance to avoid divide-by-zero.
      size = feat.size()
+     assert len(size) == 4
      N, C = size[:2]
      feat_var = feat.view(N, C, -1).var(dim=2) + eps
      feat_std = feat_var.sqrt().view(N, C, 1, 1)
      feat_mean = feat.view(N, C, -1).mean(dim=2).view(N, C, 1, 1)
      return feat_mean, feat_std

+
  def calc_mean_std1(feat, eps=1e-5):
      # eps is a small value added to the variance to avoid divide-by-zero.
      size = feat.size()
      # assert (len(size) == 4)
+     WH, N, C = size
      feat_var = feat.var(dim=0) + eps
      feat_std = feat_var.sqrt()
      feat_mean = feat.mean(dim=0)
      return feat_mean, feat_std
+
+
  def normal(feat, eps=1e-5):
+     feat_mean, feat_std = calc_mean_std(feat, eps)
+     normalized = (feat - feat_mean) / feat_std
+     return normalized
+
+
  def normal_style(feat, eps=1e-5):
+     feat_mean, feat_std = calc_mean_std1(feat, eps)
+     normalized = (feat - feat_mean) / feat_std
      return normalized

+
  def _calc_feat_flatten_mean_std(feat):
      # takes 3D feat (C, H, W), return mean and std of array within channels
+     assert feat.size()[0] == 3
+     assert isinstance(feat, torch.FloatTensor)
      feat_flatten = feat.view(3, -1)
      mean = feat_flatten.mean(dim=-1, keepdim=True)
      std = feat_flatten.std(dim=-1, keepdim=True)

      # Note: flatten -> f

      source_f, source_f_mean, source_f_std = _calc_feat_flatten_mean_std(source)
+     source_f_norm = (
+         source_f - source_f_mean.expand_as(source_f)
+     ) / source_f_std.expand_as(source_f)
+     source_f_cov_eye = torch.mm(source_f_norm, source_f_norm.t()) + torch.eye(3)

      target_f, target_f_mean, target_f_std = _calc_feat_flatten_mean_std(target)
+     target_f_norm = (
+         target_f - target_f_mean.expand_as(target_f)
+     ) / target_f_std.expand_as(target_f)
+     target_f_cov_eye = torch.mm(target_f_norm, target_f_norm.t()) + torch.eye(3)

      source_f_norm_transfer = torch.mm(
          _mat_sqrt(target_f_cov_eye),
+         torch.mm(torch.inverse(_mat_sqrt(source_f_cov_eye)), source_f_norm),
      )

+     source_f_transfer = source_f_norm_transfer * target_f_std.expand_as(
+         source_f_norm
+     ) + target_f_mean.expand_as(source_f_norm)

      return source_f_transfer.view(source.size())
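calc_mean_std and normal implement the per-channel statistics the style losses rely on: for a (N, C, H, W) feature map they return (N, C, 1, 1) mean and std tensors, and normal whitens the features with them. A quick standalone shape check using the new module path (the random tensor is just an illustration):

import torch
from StyleTransfer.srcTransformer.function import calc_mean_std, normal

feat = torch.randn(2, 512, 32, 32)       # (N, C, H, W) feature map
mean, std = calc_mean_std(feat)
print(mean.shape, std.shape)              # torch.Size([2, 512, 1, 1]) twice

whitened = normal(feat)                   # (feat - mean) / std, broadcast over H and W
print(whitened.view(2, 512, -1).mean(dim=2).abs().max())  # close to zero for every channel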
StyleTransfer/{misc.py β†’ srcTransformer/misc.py} RENAMED
@@ -4,20 +4,21 @@ Misc functions, including distributed helpers.
4
 
5
  Mostly copy-paste from torchvision references.
6
  """
 
7
  import os
 
8
  import subprocess
9
  import time
10
  from collections import defaultdict, deque
11
- import datetime
12
- import pickle
13
- from typing import Optional, List
14
 
15
  import torch
16
  import torch.distributed as dist
17
- from torch import Tensor
18
 
19
  # needed due to empty tensor bug in pytorch and torchvision 0.5
20
  import torchvision
 
 
21
  if float(torchvision.__version__[:3]) < 0.7:
22
  from torchvision.ops import _new_empty_tensor
23
  from torchvision.ops.misc import _output_size
@@ -47,7 +48,7 @@ class SmoothedValue(object):
47
  """
48
  if not is_dist_avail_and_initialized():
49
  return
50
- t = torch.tensor([self.count, self.total], dtype=torch.float64, device='cuda')
51
  dist.barrier()
52
  dist.all_reduce(t)
53
  t = t.tolist()
@@ -82,7 +83,8 @@ class SmoothedValue(object):
82
  avg=self.avg,
83
  global_avg=self.global_avg,
84
  max=self.max,
85
- value=self.value)
 
86
 
87
 
88
  def all_gather(data):
@@ -116,7 +118,9 @@ def all_gather(data):
116
  for _ in size_list:
117
  tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
118
  if local_size != max_size:
119
- padding = torch.empty(size=(max_size - local_size,), dtype=torch.uint8, device="cuda")
 
 
120
  tensor = torch.cat((tensor, padding), dim=0)
121
  dist.all_gather(tensor_list, tensor)
122
 
@@ -172,15 +176,14 @@ class MetricLogger(object):
172
  return self.meters[attr]
173
  if attr in self.__dict__:
174
  return self.__dict__[attr]
175
- raise AttributeError("'{}' object has no attribute '{}'".format(
176
- type(self).__name__, attr))
 
177
 
178
  def __str__(self):
179
  loss_str = []
180
  for name, meter in self.meters.items():
181
- loss_str.append(
182
- "{}: {}".format(name, str(meter))
183
- )
184
  return self.delimiter.join(loss_str)
185
 
186
  def synchronize_between_processes(self):
@@ -193,31 +196,35 @@ class MetricLogger(object):
193
  def log_every(self, iterable, print_freq, header=None):
194
  i = 0
195
  if not header:
196
- header = ''
197
  start_time = time.time()
198
  end = time.time()
199
- iter_time = SmoothedValue(fmt='{avg:.4f}')
200
- data_time = SmoothedValue(fmt='{avg:.4f}')
201
- space_fmt = ':' + str(len(str(len(iterable)))) + 'd'
202
  if torch.cuda.is_available():
203
- log_msg = self.delimiter.join([
204
- header,
205
- '[{0' + space_fmt + '}/{1}]',
206
- 'eta: {eta}',
207
- '{meters}',
208
- 'time: {time}',
209
- 'data: {data}',
210
- 'max mem: {memory:.0f}'
211
- ])
 
 
212
  else:
213
- log_msg = self.delimiter.join([
214
- header,
215
- '[{0' + space_fmt + '}/{1}]',
216
- 'eta: {eta}',
217
- '{meters}',
218
- 'time: {time}',
219
- 'data: {data}'
220
- ])
 
 
221
  MB = 1024.0 * 1024.0
222
  for obj in iterable:
223
  data_time.update(time.time() - end)
@@ -227,38 +234,54 @@ class MetricLogger(object):
227
  eta_seconds = iter_time.global_avg * (len(iterable) - i)
228
  eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
229
  if torch.cuda.is_available():
230
- print(log_msg.format(
231
- i, len(iterable), eta=eta_string,
232
- meters=str(self),
233
- time=str(iter_time), data=str(data_time),
234
- memory=torch.cuda.max_memory_allocated() / MB))
 
 
 
 
 
 
235
  else:
236
- print(log_msg.format(
237
- i, len(iterable), eta=eta_string,
238
- meters=str(self),
239
- time=str(iter_time), data=str(data_time)))
 
 
 
 
 
 
240
  i += 1
241
  end = time.time()
242
  total_time = time.time() - start_time
243
  total_time_str = str(datetime.timedelta(seconds=int(total_time)))
244
- print('{} Total time: {} ({:.4f} s / it)'.format(
245
- header, total_time_str, total_time / len(iterable)))
 
 
 
246
 
247
 
248
  def get_sha():
249
  cwd = os.path.dirname(os.path.abspath(__file__))
250
 
251
  def _run(command):
252
- return subprocess.check_output(command, cwd=cwd).decode('ascii').strip()
253
- sha = 'N/A'
 
254
  diff = "clean"
255
- branch = 'N/A'
256
  try:
257
- sha = _run(['git', 'rev-parse', 'HEAD'])
258
- subprocess.check_output(['git', 'diff'], cwd=cwd)
259
- diff = _run(['git', 'diff-index', 'HEAD'])
260
  diff = "has uncommited changes" if diff else "clean"
261
- branch = _run(['git', 'rev-parse', '--abbrev-ref', 'HEAD'])
262
  except Exception:
263
  pass
264
  message = f"sha: {sha}, status: {diff}, branch: {branch}"
@@ -324,9 +347,9 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
324
  mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
325
  for img, pad_img, m in zip(tensor_list, tensor, mask):
326
  pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
327
- m[: img.shape[1], :img.shape[2]] = False
328
  else:
329
- raise ValueError('not supported')
330
  return NestedTensor(tensor, mask)
331
 
332
 
@@ -336,7 +359,9 @@ def nested_tensor_from_tensor_list(tensor_list: List[Tensor]):
336
  def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
337
  max_size = []
338
  for i in range(tensor_list[0].dim()):
339
- max_size_i = torch.max(torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)).to(torch.int64)
 
 
340
  max_size.append(max_size_i)
341
  max_size = tuple(max_size)
342
 
@@ -348,11 +373,15 @@ def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTen
348
  padded_masks = []
349
  for img in tensor_list:
350
  padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
351
- padded_img = torch.nn.functional.pad(img, (0, padding[2], 0, padding[1], 0, padding[0]))
 
 
352
  padded_imgs.append(padded_img)
353
 
354
  m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
355
- padded_mask = torch.nn.functional.pad(m, (0, padding[2], 0, padding[1]), "constant", 1)
 
 
356
  padded_masks.append(padded_mask.to(torch.bool))
357
 
358
  tensor = torch.stack(padded_imgs)
@@ -366,10 +395,11 @@ def setup_for_distributed(is_master):
366
  This function disables printing when not in master process
367
  """
368
  import builtins as __builtin__
 
369
  builtin_print = __builtin__.print
370
 
371
  def print(*args, **kwargs):
372
- force = kwargs.pop('force', False)
373
  if is_master or force:
374
  builtin_print(*args, **kwargs)
375
 
@@ -406,26 +436,31 @@ def save_on_master(*args, **kwargs):
406
 
407
 
408
  def init_distributed_mode(args):
409
- if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ:
410
  args.rank = int(os.environ["RANK"])
411
- args.world_size = int(os.environ['WORLD_SIZE'])
412
- args.gpu = int(os.environ['LOCAL_RANK'])
413
- elif 'SLURM_PROCID' in os.environ:
414
- args.rank = int(os.environ['SLURM_PROCID'])
415
  args.gpu = args.rank % torch.cuda.device_count()
416
  else:
417
- print('Not using distributed mode')
418
  args.distributed = False
419
  return
420
 
421
  args.distributed = True
422
 
423
  torch.cuda.set_device(args.gpu)
424
- args.dist_backend = 'nccl'
425
- print('| distributed init (rank {}): {}'.format(
426
- args.rank, args.dist_url), flush=True)
427
- torch.distributed.init_process_group(backend=args.dist_backend, init_method=args.dist_url,
428
- world_size=args.world_size, rank=args.rank)
 
 
 
 
 
429
  torch.distributed.barrier()
430
  setup_for_distributed(args.rank == 0)
431
 
@@ -449,8 +484,14 @@ def accuracy(output, target, topk=(1,)):
449
  return res
450
 
451
 
452
- def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corners=None):
453
- # type: (Tensor, Optional[List[int]], Optional[float], str, Optional[bool]) -> Tensor
 
 
 
 
 
 
454
  """
455
  Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
456
  This will eventually be supported natively by PyTorch, and this
@@ -466,4 +507,6 @@ def interpolate(input, size=None, scale_factor=None, mode="nearest", align_corne
466
  output_shape = list(input.shape[:-2]) + list(output_shape)
467
  return _new_empty_tensor(input, output_shape)
468
  else:
469
- return torchvision.ops.misc.interpolate(input, size, scale_factor, mode, align_corners)
 
 
 
4
 
5
  Mostly copy-paste from torchvision references.
6
  """
7
+ import datetime
8
  import os
9
+ import pickle
10
  import subprocess
11
  import time
12
  from collections import defaultdict, deque
13
+ from typing import List, Optional
 
 
14
 
15
  import torch
16
  import torch.distributed as dist
 
17
 
18
  # needed due to empty tensor bug in pytorch and torchvision 0.5
19
  import torchvision
20
+ from torch import Tensor
21
+
22
  if float(torchvision.__version__[:3]) < 0.7:
23
  from torchvision.ops import _new_empty_tensor
24
  from torchvision.ops.misc import _output_size
 
48
  """
49
  if not is_dist_avail_and_initialized():
50
  return
51
+ t = torch.tensor([self.count, self.total], dtype=torch.float64, device="cuda")
52
  dist.barrier()
53
  dist.all_reduce(t)
54
  t = t.tolist()
 
83
  avg=self.avg,
84
  global_avg=self.global_avg,
85
  max=self.max,
86
+ value=self.value,
87
+ )
88
 
89
 
90
  def all_gather(data):
 
118
  for _ in size_list:
119
  tensor_list.append(torch.empty((max_size,), dtype=torch.uint8, device="cuda"))
120
  if local_size != max_size:
121
+ padding = torch.empty(
122
+ size=(max_size - local_size,), dtype=torch.uint8, device="cuda"
123
+ )
124
  tensor = torch.cat((tensor, padding), dim=0)
125
  dist.all_gather(tensor_list, tensor)
126
 
 
176
  return self.meters[attr]
177
  if attr in self.__dict__:
178
  return self.__dict__[attr]
179
+ raise AttributeError(
180
+ "'{}' object has no attribute '{}'".format(type(self).__name__, attr)
181
+ )
182
 
183
  def __str__(self):
184
  loss_str = []
185
  for name, meter in self.meters.items():
186
+ loss_str.append("{}: {}".format(name, str(meter)))
 
 
187
  return self.delimiter.join(loss_str)
188
 
189
  def synchronize_between_processes(self):
 
196
  def log_every(self, iterable, print_freq, header=None):
197
  i = 0
198
  if not header:
199
+ header = ""
200
  start_time = time.time()
201
  end = time.time()
202
+ iter_time = SmoothedValue(fmt="{avg:.4f}")
203
+ data_time = SmoothedValue(fmt="{avg:.4f}")
204
+ space_fmt = ":" + str(len(str(len(iterable)))) + "d"
205
  if torch.cuda.is_available():
206
+ log_msg = self.delimiter.join(
207
+ [
208
+ header,
209
+ "[{0" + space_fmt + "}/{1}]",
210
+ "eta: {eta}",
211
+ "{meters}",
212
+ "time: {time}",
213
+ "data: {data}",
214
+ "max mem: {memory:.0f}",
215
+ ]
216
+ )
217
  else:
218
+ log_msg = self.delimiter.join(
219
+ [
220
+ header,
221
+ "[{0" + space_fmt + "}/{1}]",
222
+ "eta: {eta}",
223
+ "{meters}",
224
+ "time: {time}",
225
+ "data: {data}",
226
+ ]
227
+ )
228
  MB = 1024.0 * 1024.0
229
  for obj in iterable:
230
  data_time.update(time.time() - end)
 
234
  eta_seconds = iter_time.global_avg * (len(iterable) - i)
235
  eta_string = str(datetime.timedelta(seconds=int(eta_seconds)))
236
  if torch.cuda.is_available():
237
+ print(
238
+ log_msg.format(
239
+ i,
240
+ len(iterable),
241
+ eta=eta_string,
242
+ meters=str(self),
243
+ time=str(iter_time),
244
+ data=str(data_time),
245
+ memory=torch.cuda.max_memory_allocated() / MB,
246
+ )
247
+ )
248
  else:
249
+ print(
250
+ log_msg.format(
251
+ i,
252
+ len(iterable),
253
+ eta=eta_string,
254
+ meters=str(self),
255
+ time=str(iter_time),
256
+ data=str(data_time),
257
+ )
258
+ )
259
  i += 1
260
  end = time.time()
261
  total_time = time.time() - start_time
262
  total_time_str = str(datetime.timedelta(seconds=int(total_time)))
263
+ print(
264
+ "{} Total time: {} ({:.4f} s / it)".format(
265
+ header, total_time_str, total_time / len(iterable)
266
+ )
267
+ )
268
 
269
 
270
  def get_sha():
271
  cwd = os.path.dirname(os.path.abspath(__file__))
272
 
273
  def _run(command):
274
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
275
+
276
+ sha = "N/A"
277
  diff = "clean"
278
+ branch = "N/A"
279
  try:
280
+ sha = _run(["git", "rev-parse", "HEAD"])
281
+ subprocess.check_output(["git", "diff"], cwd=cwd)
282
+ diff = _run(["git", "diff-index", "HEAD"])
283
  diff = "has uncommited changes" if diff else "clean"
284
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
285
  except Exception:
286
  pass
287
  message = f"sha: {sha}, status: {diff}, branch: {branch}"
 
347
  mask = torch.ones((b, h, w), dtype=torch.bool, device=device)
348
  for img, pad_img, m in zip(tensor_list, tensor, mask):
349
  pad_img[: img.shape[0], : img.shape[1], : img.shape[2]].copy_(img)
350
+ m[: img.shape[1], : img.shape[2]] = False
351
  else:
352
+ raise ValueError("not supported")
353
  return NestedTensor(tensor, mask)
354
 
355
 
 
359
  def _onnx_nested_tensor_from_tensor_list(tensor_list: List[Tensor]) -> NestedTensor:
360
  max_size = []
361
  for i in range(tensor_list[0].dim()):
362
+ max_size_i = torch.max(
363
+ torch.stack([img.shape[i] for img in tensor_list]).to(torch.float32)
364
+ ).to(torch.int64)
365
  max_size.append(max_size_i)
366
  max_size = tuple(max_size)
367
 
 
373
  padded_masks = []
374
  for img in tensor_list:
375
  padding = [(s1 - s2) for s1, s2 in zip(max_size, tuple(img.shape))]
376
+ padded_img = torch.nn.functional.pad(
377
+ img, (0, padding[2], 0, padding[1], 0, padding[0])
378
+ )
379
  padded_imgs.append(padded_img)
380
 
381
  m = torch.zeros_like(img[0], dtype=torch.int, device=img.device)
382
+ padded_mask = torch.nn.functional.pad(
383
+ m, (0, padding[2], 0, padding[1]), "constant", 1
384
+ )
385
  padded_masks.append(padded_mask.to(torch.bool))
386
 
387
  tensor = torch.stack(padded_imgs)
 
395
  This function disables printing when not in master process
396
  """
397
  import builtins as __builtin__
398
+
399
  builtin_print = __builtin__.print
400
 
401
  def print(*args, **kwargs):
402
+ force = kwargs.pop("force", False)
403
  if is_master or force:
404
  builtin_print(*args, **kwargs)
405
 
 
436
 
437
 
438
  def init_distributed_mode(args):
439
+ if "RANK" in os.environ and "WORLD_SIZE" in os.environ:
440
  args.rank = int(os.environ["RANK"])
441
+ args.world_size = int(os.environ["WORLD_SIZE"])
442
+ args.gpu = int(os.environ["LOCAL_RANK"])
443
+ elif "SLURM_PROCID" in os.environ:
444
+ args.rank = int(os.environ["SLURM_PROCID"])
445
  args.gpu = args.rank % torch.cuda.device_count()
446
  else:
447
+ print("Not using distributed mode")
448
  args.distributed = False
449
  return
450
 
451
  args.distributed = True
452
 
453
  torch.cuda.set_device(args.gpu)
454
+ args.dist_backend = "nccl"
455
+ print(
456
+ "| distributed init (rank {}): {}".format(args.rank, args.dist_url), flush=True
457
+ )
458
+ torch.distributed.init_process_group(
459
+ backend=args.dist_backend,
460
+ init_method=args.dist_url,
461
+ world_size=args.world_size,
462
+ rank=args.rank,
463
+ )
464
  torch.distributed.barrier()
465
  setup_for_distributed(args.rank == 0)
466
 
 
484
  return res
485
 
486
 
487
+ def interpolate(
488
+ input: torch.tensor,
489
+ size: List[int] = None,
490
+ scale_factor: float = None,
491
+ mode: str = "nearest",
492
+ align_corners: bool = None,
493
+ ) -> torch.tensor:
494
+
495
  """
496
  Equivalent to nn.functional.interpolate, but with support for empty batch sizes.
497
  This will eventually be supported natively by PyTorch, and this
 
507
  output_shape = list(input.shape[:-2]) + list(output_shape)
508
  return _new_empty_tensor(input, output_shape)
509
  else:
510
+ return torchvision.ops.misc.interpolate(
511
+ input, size, scale_factor, mode, align_corners
512
+ )
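A robustness note on the version guard this file keeps: float(torchvision.__version__[:3]) misparses two-digit minor versions (torchvision 0.10 becomes 0.1 and wrongly takes the legacy branch). A more defensive check, offered as a suggestion rather than something this commit changes:

# Suggested alternative to float(torchvision.__version__[:3]) < 0.7 -- not in this commit.
import torchvision
from packaging import version

if version.parse(torchvision.__version__) < version.parse("0.7"):
    from torchvision.ops import _new_empty_tensor
    from torchvision.ops.misc import _output_size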
StyleTransfer/{transformer.py β†’ srcTransformer/transformer.py} RENAMED
@@ -1,40 +1,59 @@
1
  import copy
2
- from typing import Optional, List
 
3
 
 
4
  import torch
5
  import torch.nn.functional as F
6
- from torch import nn, Tensor
7
- from StyleTransfer.function import normal,normal_style
8
- import numpy as np
9
- import os
10
  device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
11
  os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"
12
- class Transformer(nn.Module):
13
 
14
- def __init__(self, d_model=512, nhead=8, num_encoder_layers=3,
15
- num_decoder_layers=3, dim_feedforward=2048, dropout=0.1,
16
- activation="relu", normalize_before=False,
17
- return_intermediate_dec=False):
 
 
 
 
 
 
 
 
 
 
18
  super().__init__()
19
 
20
- encoder_layer = TransformerEncoderLayer(d_model, nhead, dim_feedforward,
21
- dropout, activation, normalize_before)
 
22
  encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
23
- self.encoder_c = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
24
- self.encoder_s = TransformerEncoder(encoder_layer, num_encoder_layers, encoder_norm)
25
-
26
- decoder_layer = TransformerDecoderLayer(d_model, nhead, dim_feedforward,
27
- dropout, activation, normalize_before)
 
 
 
 
 
28
  decoder_norm = nn.LayerNorm(d_model)
29
- self.decoder = TransformerDecoder(decoder_layer, num_decoder_layers, decoder_norm,
30
- return_intermediate=return_intermediate_dec)
 
 
 
 
31
 
32
  self._reset_parameters()
33
 
34
  self.d_model = d_model
35
  self.nhead = nhead
36
 
37
- self.new_ps = nn.Conv2d(512 , 512 , (1,1))
38
  self.averagepooling = nn.AdaptiveAvgPool2d(18)
39
 
40
  def _reset_parameters(self):
@@ -42,54 +61,64 @@ class Transformer(nn.Module):
42
  if p.dim() > 1:
43
  nn.init.xavier_uniform_(p)
44
 
45
- def forward(self, style, mask , content, pos_embed_c, pos_embed_s):
46
 
47
  # content-aware positional embedding
48
- content_pool = self.averagepooling(content)
49
  pos_c = self.new_ps(content_pool)
50
- pos_embed_c = F.interpolate(pos_c, mode='bilinear',size= style.shape[-2:])
51
 
52
- ###flatten NxCxHxW to HWxNxC
53
  style = style.flatten(2).permute(2, 0, 1)
54
  if pos_embed_s is not None:
55
  pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)
56
-
57
  content = content.flatten(2).permute(2, 0, 1)
58
  if pos_embed_c is not None:
59
  pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)
60
-
61
-
62
  style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
63
  content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
64
- hs = self.decoder(content, style, memory_key_padding_mask=mask,
65
- pos=pos_embed_s, query_pos=pos_embed_c)[0]
66
-
67
- ### HWxNxC to NxCxHxW to
68
- N, B, C= hs.shape
 
 
 
 
 
69
  H = int(np.sqrt(N))
70
  hs = hs.permute(1, 2, 0)
71
- hs = hs.view(B, C, -1,H)
72
 
73
  return hs
74
 
75
 
76
  class TransformerEncoder(nn.Module):
77
-
78
  def __init__(self, encoder_layer, num_layers, norm=None):
79
  super().__init__()
80
  self.layers = _get_clones(encoder_layer, num_layers)
81
  self.num_layers = num_layers
82
  self.norm = norm
83
 
84
- def forward(self, src,
85
- mask: Optional[Tensor] = None,
86
- src_key_padding_mask: Optional[Tensor] = None,
87
- pos: Optional[Tensor] = None):
 
 
 
88
  output = src
89
-
90
  for layer in self.layers:
91
- output = layer(output, src_mask=mask,
92
- src_key_padding_mask=src_key_padding_mask, pos=pos)
 
 
 
 
93
 
94
  if self.norm is not None:
95
  output = self.norm(output)
@@ -98,7 +127,6 @@ class TransformerEncoder(nn.Module):
98
 
99
 
100
  class TransformerDecoder(nn.Module):
101
-
102
  def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
103
  super().__init__()
104
  self.layers = _get_clones(decoder_layer, num_layers)
@@ -106,23 +134,32 @@ class TransformerDecoder(nn.Module):
106
  self.norm = norm
107
  self.return_intermediate = return_intermediate
108
 
109
- def forward(self, tgt, memory,
110
- tgt_mask: Optional[Tensor] = None,
111
- memory_mask: Optional[Tensor] = None,
112
- tgt_key_padding_mask: Optional[Tensor] = None,
113
- memory_key_padding_mask: Optional[Tensor] = None,
114
- pos: Optional[Tensor] = None,
115
- query_pos: Optional[Tensor] = None):
 
 
 
 
116
  output = tgt
117
 
118
  intermediate = []
119
 
120
  for layer in self.layers:
121
- output = layer(output, memory, tgt_mask=tgt_mask,
122
- memory_mask=memory_mask,
123
- tgt_key_padding_mask=tgt_key_padding_mask,
124
- memory_key_padding_mask=memory_key_padding_mask,
125
- pos=pos, query_pos=query_pos)
 
 
 
 
 
126
  if self.return_intermediate:
127
  intermediate.append(self.norm(output))
128
 
@@ -139,9 +176,15 @@ class TransformerDecoder(nn.Module):
139
 
140
 
141
  class TransformerEncoderLayer(nn.Module):
142
-
143
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
144
- activation="relu", normalize_before=False):
 
 
 
 
 
 
145
  super().__init__()
146
  self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
147
  # Implementation of Feedforward model
@@ -160,16 +203,19 @@ class TransformerEncoderLayer(nn.Module):
160
  def with_pos_embed(self, tensor, pos: Optional[Tensor]):
161
  return tensor if pos is None else tensor + pos
162
 
163
- def forward_post(self,
164
- src,
165
- src_mask: Optional[Tensor] = None,
166
- src_key_padding_mask: Optional[Tensor] = None,
167
- pos: Optional[Tensor] = None):
 
 
168
  q = k = self.with_pos_embed(src, pos)
169
  # q = k = src
170
  # print(q.size(),k.size(),src.size())
171
- src2 = self.self_attn(q, k, value=src, attn_mask=src_mask,
172
- key_padding_mask=src_key_padding_mask)[0]
 
173
  src = src + self.dropout1(src2)
174
  src = self.norm1(src)
175
  src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
@@ -177,33 +223,46 @@ class TransformerEncoderLayer(nn.Module):
177
  src = self.norm2(src)
178
  return src
179
 
180
- def forward_pre(self, src,
181
- src_mask: Optional[Tensor] = None,
182
- src_key_padding_mask: Optional[Tensor] = None,
183
- pos: Optional[Tensor] = None):
 
 
 
184
  src2 = self.norm1(src)
185
  q = k = self.with_pos_embed(src2, pos)
186
- src2 = self.self_attn(q, k, value=src2, attn_mask=src_mask,
187
- key_padding_mask=src_key_padding_mask)[0]
 
188
  src = src + self.dropout1(src2)
189
  src2 = self.norm2(src)
190
  src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
191
  src = src + self.dropout2(src2)
192
  return src
193
 
194
- def forward(self, src,
195
- src_mask: Optional[Tensor] = None,
196
- src_key_padding_mask: Optional[Tensor] = None,
197
- pos: Optional[Tensor] = None):
 
 
 
198
  if self.normalize_before:
199
  return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
200
  return self.forward_post(src, src_mask, src_key_padding_mask, pos)
201
 
202
 
203
  class TransformerDecoderLayer(nn.Module):
204
-
205
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1,
206
- activation="relu", normalize_before=False):
 
 
 
 
 
 
207
  super().__init__()
208
  # d_model embedding dim
209
  self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
@@ -226,28 +285,35 @@ class TransformerDecoderLayer(nn.Module):
226
  def with_pos_embed(self, tensor, pos: Optional[Tensor]):
227
  return tensor if pos is None else tensor + pos
228
 
229
- def forward_post(self, tgt, memory,
230
- tgt_mask: Optional[Tensor] = None,
231
- memory_mask: Optional[Tensor] = None,
232
- tgt_key_padding_mask: Optional[Tensor] = None,
233
- memory_key_padding_mask: Optional[Tensor] = None,
234
- pos: Optional[Tensor] = None,
235
- query_pos: Optional[Tensor] = None):
 
 
 
 
236
 
237
-
238
  q = self.with_pos_embed(tgt, query_pos)
239
  k = self.with_pos_embed(memory, pos)
240
- v = memory
241
-
242
- tgt2 = self.self_attn(q, k, v, attn_mask=tgt_mask,
243
- key_padding_mask=tgt_key_padding_mask)[0]
244
-
 
245
  tgt = tgt + self.dropout1(tgt2)
246
  tgt = self.norm1(tgt)
247
- tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt, query_pos),
248
- key=self.with_pos_embed(memory, pos),
249
- value=memory, attn_mask=memory_mask,
250
- key_padding_mask=memory_key_padding_mask)[0]
 
 
 
251
  tgt = tgt + self.dropout2(tgt2)
252
  tgt = self.norm2(tgt)
253
  tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
@@ -255,24 +321,32 @@ class TransformerDecoderLayer(nn.Module):
255
  tgt = self.norm3(tgt)
256
  return tgt
257
 
258
- def forward_pre(self, tgt, memory,
259
- tgt_mask: Optional[Tensor] = None,
260
- memory_mask: Optional[Tensor] = None,
261
- tgt_key_padding_mask: Optional[Tensor] = None,
262
- memory_key_padding_mask: Optional[Tensor] = None,
263
- pos: Optional[Tensor] = None,
264
- query_pos: Optional[Tensor] = None):
 
 
 
 
265
  tgt2 = self.norm1(tgt)
266
  q = k = self.with_pos_embed(tgt2, query_pos)
267
- tgt2 = self.self_attn(q, k, value=tgt2, attn_mask=tgt_mask,
268
- key_padding_mask=tgt_key_padding_mask)[0]
 
269
 
270
  tgt = tgt + self.dropout1(tgt2)
271
  tgt2 = self.norm2(tgt)
272
- tgt2 = self.multihead_attn(query=self.with_pos_embed(tgt2, query_pos),
273
- key=self.with_pos_embed(memory, pos),
274
- value=memory, attn_mask=memory_mask,
275
- key_padding_mask=memory_key_padding_mask)[0]
 
 
 
276
 
277
  tgt = tgt + self.dropout2(tgt2)
278
  tgt2 = self.norm3(tgt)
@@ -280,18 +354,38 @@ class TransformerDecoderLayer(nn.Module):
280
  tgt = tgt + self.dropout3(tgt2)
281
  return tgt
282
 
283
- def forward(self, tgt, memory,
284
- tgt_mask: Optional[Tensor] = None,
285
- memory_mask: Optional[Tensor] = None,
286
- tgt_key_padding_mask: Optional[Tensor] = None,
287
- memory_key_padding_mask: Optional[Tensor] = None,
288
- pos: Optional[Tensor] = None,
289
- query_pos: Optional[Tensor] = None):
 
 
 
 
290
  if self.normalize_before:
291
- return self.forward_pre(tgt, memory, tgt_mask, memory_mask,
292
- tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
293
- return self.forward_post(tgt, memory, tgt_mask, memory_mask,
294
- tgt_key_padding_mask, memory_key_padding_mask, pos, query_pos)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
295
 
296
 
297
  def _get_clones(module, N):
@@ -319,4 +413,4 @@ def _get_activation_fn(activation):
319
  return F.gelu
320
  if activation == "glu":
321
  return F.glu
322
- raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
 
 import copy
+ import os
+ from typing import Optional

+ import numpy as np
 import torch
 import torch.nn.functional as F
+ from torch import Tensor, nn
+
 device = torch.device("cuda:2" if torch.cuda.is_available() else "cpu")
 os.environ["CUDA_VISIBLE_DEVICES"] = "2, 3"

+
+ class Transformer(nn.Module):
+ def __init__(
+ self,
+ d_model=512,
+ nhead=8,
+ num_encoder_layers=3,
+ num_decoder_layers=3,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ return_intermediate_dec=False,
+ ):
 super().__init__()

+ encoder_layer = TransformerEncoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
 encoder_norm = nn.LayerNorm(d_model) if normalize_before else None
+ self.encoder_c = TransformerEncoder(
+ encoder_layer, num_encoder_layers, encoder_norm
+ )
+ self.encoder_s = TransformerEncoder(
+ encoder_layer, num_encoder_layers, encoder_norm
+ )
+
+ decoder_layer = TransformerDecoderLayer(
+ d_model, nhead, dim_feedforward, dropout, activation, normalize_before
+ )
 decoder_norm = nn.LayerNorm(d_model)
+ self.decoder = TransformerDecoder(
+ decoder_layer,
+ num_decoder_layers,
+ decoder_norm,
+ return_intermediate=return_intermediate_dec,
+ )

 self._reset_parameters()

 self.d_model = d_model
 self.nhead = nhead

+ self.new_ps = nn.Conv2d(512, 512, (1, 1))
 self.averagepooling = nn.AdaptiveAvgPool2d(18)

 def _reset_parameters(self):

 if p.dim() > 1:
 nn.init.xavier_uniform_(p)

+ def forward(self, style, mask, content, pos_embed_c, pos_embed_s):

 # content-aware positional embedding
+ content_pool = self.averagepooling(content)
 pos_c = self.new_ps(content_pool)
+ pos_embed_c = F.interpolate(pos_c, mode="bilinear", size=style.shape[-2:])

+ # flatten NxCxHxW to HWxNxC
 style = style.flatten(2).permute(2, 0, 1)
 if pos_embed_s is not None:
 pos_embed_s = pos_embed_s.flatten(2).permute(2, 0, 1)
+
 content = content.flatten(2).permute(2, 0, 1)
 if pos_embed_c is not None:
 pos_embed_c = pos_embed_c.flatten(2).permute(2, 0, 1)
+
 style = self.encoder_s(style, src_key_padding_mask=mask, pos=pos_embed_s)
 content = self.encoder_c(content, src_key_padding_mask=mask, pos=pos_embed_c)
+ hs = self.decoder(
+ content,
+ style,
+ memory_key_padding_mask=mask,
+ pos=pos_embed_s,
+ query_pos=pos_embed_c,
+ )[0]
+
+ # HWxNxC to NxCxHxW to
+ N, B, C = hs.shape
 H = int(np.sqrt(N))
 hs = hs.permute(1, 2, 0)
+ hs = hs.view(B, C, -1, H)

 return hs


 class TransformerEncoder(nn.Module):

 def __init__(self, encoder_layer, num_layers, norm=None):
 super().__init__()
 self.layers = _get_clones(encoder_layer, num_layers)
 self.num_layers = num_layers
 self.norm = norm

+ def forward(
+ self,
+ src,
+ mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
 output = src
+
 for layer in self.layers:
+ output = layer(
+ output,
+ src_mask=mask,
+ src_key_padding_mask=src_key_padding_mask,
+ pos=pos,
+ )

 if self.norm is not None:
 output = self.norm(output)



 class TransformerDecoder(nn.Module):

 def __init__(self, decoder_layer, num_layers, norm=None, return_intermediate=False):
 super().__init__()
 self.layers = _get_clones(decoder_layer, num_layers)

 self.norm = norm
 self.return_intermediate = return_intermediate

+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
 output = tgt

 intermediate = []

 for layer in self.layers:
+ output = layer(
+ output,
+ memory,
+ tgt_mask=tgt_mask,
+ memory_mask=memory_mask,
+ tgt_key_padding_mask=tgt_key_padding_mask,
+ memory_key_padding_mask=memory_key_padding_mask,
+ pos=pos,
+ query_pos=query_pos,
+ )
 if self.return_intermediate:
 intermediate.append(self.norm(output))



 class TransformerEncoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
 super().__init__()
 self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
 # Implementation of Feedforward model

 def with_pos_embed(self, tensor, pos: Optional[Tensor]):
 return tensor if pos is None else tensor + pos

+ def forward_post(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
 q = k = self.with_pos_embed(src, pos)
 # q = k = src
 # print(q.size(),k.size(),src.size())
+ src2 = self.self_attn(
+ q, k, value=src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
 src = src + self.dropout1(src2)
 src = self.norm1(src)
 src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))

 src = self.norm2(src)
 return src

+ def forward_pre(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
 src2 = self.norm1(src)
 q = k = self.with_pos_embed(src2, pos)
+ src2 = self.self_attn(
+ q, k, value=src2, attn_mask=src_mask, key_padding_mask=src_key_padding_mask
+ )[0]
 src = src + self.dropout1(src2)
 src2 = self.norm2(src)
 src2 = self.linear2(self.dropout(self.activation(self.linear1(src2))))
 src = src + self.dropout2(src2)
 return src

+ def forward(
+ self,
+ src,
+ src_mask: Optional[Tensor] = None,
+ src_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ ):
 if self.normalize_before:
 return self.forward_pre(src, src_mask, src_key_padding_mask, pos)
 return self.forward_post(src, src_mask, src_key_padding_mask, pos)


 class TransformerDecoderLayer(nn.Module):
+ def __init__(
+ self,
+ d_model,
+ nhead,
+ dim_feedforward=2048,
+ dropout=0.1,
+ activation="relu",
+ normalize_before=False,
+ ):
 super().__init__()
 # d_model embedding dim
 self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)

 def with_pos_embed(self, tensor, pos: Optional[Tensor]):
 return tensor if pos is None else tensor + pos

+ def forward_post(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):

 q = self.with_pos_embed(tgt, query_pos)
 k = self.with_pos_embed(memory, pos)
+ v = memory
+
+ tgt2 = self.self_attn(
+ q, k, v, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]
+
 tgt = tgt + self.dropout1(tgt2)
 tgt = self.norm1(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]
 tgt = tgt + self.dropout2(tgt2)
 tgt = self.norm2(tgt)
 tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))

 tgt = self.norm3(tgt)
 return tgt

+ def forward_pre(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
 tgt2 = self.norm1(tgt)
 q = k = self.with_pos_embed(tgt2, query_pos)
+ tgt2 = self.self_attn(
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
+ )[0]

 tgt = tgt + self.dropout1(tgt2)
 tgt2 = self.norm2(tgt)
+ tgt2 = self.multihead_attn(
+ query=self.with_pos_embed(tgt2, query_pos),
+ key=self.with_pos_embed(memory, pos),
+ value=memory,
+ attn_mask=memory_mask,
+ key_padding_mask=memory_key_padding_mask,
+ )[0]

 tgt = tgt + self.dropout2(tgt2)
 tgt2 = self.norm3(tgt)

 tgt = tgt + self.dropout3(tgt2)
 return tgt

+ def forward(
+ self,
+ tgt,
+ memory,
+ tgt_mask: Optional[Tensor] = None,
+ memory_mask: Optional[Tensor] = None,
+ tgt_key_padding_mask: Optional[Tensor] = None,
+ memory_key_padding_mask: Optional[Tensor] = None,
+ pos: Optional[Tensor] = None,
+ query_pos: Optional[Tensor] = None,
+ ):
 if self.normalize_before:
+ return self.forward_pre(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )
+ return self.forward_post(
+ tgt,
+ memory,
+ tgt_mask,
+ memory_mask,
+ tgt_key_padding_mask,
+ memory_key_padding_mask,
+ pos,
+ query_pos,
+ )


 def _get_clones(module, N):

 return F.gelu
 if activation == "glu":
 return F.glu
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
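The rewritten transformer module above pairs two encoders (one over content features, one over style features) with a decoder that lets the content tokens attend to the style memory, and it derives a content-aware positional embedding from an average-pooled content map. Below is a minimal sketch of how the module could be exercised with dummy 512-channel feature maps; it is illustrative only, not part of this commit, and it assumes the decoder's `[0]` indexing yields the HWxNxC tensor that `forward` reshapes at the end.

```python
# Illustrative sketch, not part of the commit. Uses the module path introduced
# by this change and dummy feature maps in place of real VGG activations.
import torch
from StyleTransfer.srcTransformer.transformer import Transformer

net = Transformer(d_model=512, nhead=8).eval()
content_feat = torch.randn(1, 512, 32, 32)  # N x C x H x W content features
style_feat = torch.randn(1, 512, 32, 32)    # N x C x H x W style features

with torch.no_grad():
    # mask and both positional embeddings may be None; the content-aware
    # embedding is recomputed inside forward() from the pooled content map
    hs = net(style_feat, None, content_feat, None, None)
print(hs.shape)  # expected (1, 512, 32, 32) for square inputs
```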
StyleTransfer/styleTransfer.py CHANGED
@@ -1,115 +1,181 @@
- from PIL import Image
 import numpy as np
 import torch
- print(torch.cuda.is_available())
 import torch.nn as nn
 from torchvision import transforms
- import StyleTransfer.transformer as transformer
- import StyleTransfer.StyTR as StyTR
- from collections import OrderedDict
- import tensorflow_hub as tfhub
- import tensorflow as tf
- import paddlehub as phub
- import os
-
- ############################################# TRANSFORMER ############################################

- def style_transform(h:int,w:int) -> transforms.Compose:
- k = (h,w)
- transform_list = []
- transform_list.append(transforms.CenterCrop((h,w)))
 transform_list.append(transforms.ToTensor())
 transform = transforms.Compose(transform_list)
 return transform

- def content_transform() -> transforms.Compose:
- transform_list = []
 transform_list.append(transforms.ToTensor())
 transform = transforms.Compose(transform_list)
 return transform

 def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
- vgg_path = 'StyleTransfer/models/vgg_normalised.pth'
- decoder_path = 'StyleTransfer/models/decoder_iter_160000.pth'
- Trans_path = 'StyleTransfer/models/transformer_iter_160000.pth'
- embedding_path = 'StyleTransfer/models/embedding_iter_160000.pth'
- # Advanced options
- content_size=640
- style_size=640
-
- vgg = StyTR.vgg
- vgg.load_state_dict(torch.load(vgg_path))
- vgg = nn.Sequential(*list(vgg.children())[:44])
-
- decoder = StyTR.decoder
- Trans = transformer.Transformer()
- embedding = StyTR.PatchEmbed()
 decoder.eval()
 Trans.eval()
 vgg.eval()

- new_state_dict = OrderedDict()
 state_dict = torch.load(decoder_path)
 decoder.load_state_dict(state_dict)

- new_state_dict = OrderedDict()
 state_dict = torch.load(Trans_path)
 Trans.load_state_dict(state_dict)

- new_state_dict = OrderedDict()
 state_dict = torch.load(embedding_path)
 embedding.load_state_dict(state_dict)

- network = StyTR.StyTrans(vgg,decoder,embedding,Trans)
 network.eval()
- content_tf = content_transform()
- style_tf = style_transform(style_size,style_size)

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 network.to(device)
- content = content_tf(content_img.convert("RGB"))
 style = style_tf(style_img.convert("RGB"))
 style = style.to(device).unsqueeze(0)
 content = content.to(device).unsqueeze(0)
 with torch.no_grad():
- output= network(content,style)
 output = output[0].cpu().squeeze()
- output = output.mul(255).add_(0.5).clamp_(0, 255).permute(1, 2, 0).to('cpu', torch.uint8).numpy()
 return Image.fromarray(output)
-
- ############################################## STYLE-FAST #############################################
- style_transfer_model = tfhub.load("https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2")

- def StyleFAST(content_image:Image.Image, style_image:Image.Image) -> Image.Image:
- content_image = tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...] / 255.
- style_image = tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.
 output = style_transfer_model(content_image, style_image)
 stylized_image = output[0]
 return Image.fromarray(np.uint8(stylized_image[0] * 255))

- ########################################### STYLE PROJECTION ##########################################
- os.system("hub install stylepro_artistic==1.0.1")
 stylepro_artistic = phub.Module(name="stylepro_artistic")
- def StyleProjection(content_image:Image.Image,style_image:Image.Image) -> Image.Image:
- print('line92')
- result = stylepro_artistic.style_transfer(
- images=[{
- 'content': np.array(content_image.convert('RGB') )[:, :, ::-1],
- 'styles': [np.array(style_image.convert('RGB') )[:, :, ::-1]]}],
- alpha=0.8)
- print('line97')
- return Image.fromarray(np.uint8(result[0]['data'])[:,:,::-1]).convert('RGB')
-
- def create_styledSofa(content_image:Image.Image,style_image:Image.Image,choice:str) -> Image.Image:
- if choice =="Style Transformer":
- output = StyleTransformer(content_image,style_image)
- elif choice =="Style FAST":
- output = StyleFAST(content_image,style_image)
- elif choice =="Style Projection":
- output = StyleProjection(content_image,style_image)
- else:
- output = content_image
- return output
-

 import numpy as np
+ import paddlehub as phub
+ import StyleTransfer.srcTransformer.StyTR as StyTR
+ import StyleTransfer.srcTransformer.transformer as transformer
+ import tensorflow as tf
+ import tensorflow_hub as tfhub
 import torch
 import torch.nn as nn
+ from PIL import Image
 from torchvision import transforms

+ # TRANSFORMER
+
+ vgg_path = "StyleTransfer/srcTransformer/Transformer_models/vgg_normalised.pth"
+ decoder_path = "StyleTransfer/srcTransformer/Transformer_models/decoder_iter_160000.pth"
+ Trans_path = (
+ "StyleTransfer/srcTransformer/Transformer_models/transformer_iter_160000.pth"
+ )
+ embedding_path = (
+ "StyleTransfer/srcTransformer/Transformer_models/embedding_iter_160000.pth"
+ )
+
+
+ def style_transform(h, w):
+ """
+ This function creates a transformation for the style image,
+ that crops it and formats it into a tensor.
+
+ Parameters:
+ h = height
+ w = width
+ Return:
+ transform = transformation pipeline
+ """
+ transform_list = []
+ transform_list.append(transforms.CenterCrop((h, w)))
 transform_list.append(transforms.ToTensor())
 transform = transforms.Compose(transform_list)
 return transform

+
+ def content_transform():
+ """
+ This function simply creates a transformation pipeline,
+ that formats the content image into a tensor.
+
+ Returns:
+ transform = the transformation pipeline
+ """
+ transform_list = []
 transform_list.append(transforms.ToTensor())
 transform = transforms.Compose(transform_list)
 return transform

+
+ # This loads the network architecture already at building time
+ vgg = StyTR.vgg
+ vgg.load_state_dict(torch.load(vgg_path))
+ vgg = nn.Sequential(*list(vgg.children())[:44])
+ decoder = StyTR.decoder
+ Trans = transformer.Transformer()
+ embedding = StyTR.PatchEmbed()
+ # The (square) shape of the content and style image is fixed
+ content_size = 640
+ style_size = 640
+
+
 def StyleTransformer(content_img: Image.Image, style_img: Image.Image) -> Image.Image:
+ """
+ This function creates the Transformer network and applies it on
+ a content and style image to create a styled image.
+
+ Parameters:
+ content_img = the image with the content
+ style_img = the image with the style/pattern
+ Returns:
+ output = an image that is a combination of both
+ """
+
 decoder.eval()
 Trans.eval()
 vgg.eval()

 state_dict = torch.load(decoder_path)
 decoder.load_state_dict(state_dict)

 state_dict = torch.load(Trans_path)
 Trans.load_state_dict(state_dict)

 state_dict = torch.load(embedding_path)
 embedding.load_state_dict(state_dict)

+ network = StyTR.StyTrans(vgg, decoder, embedding, Trans)
 network.eval()
+
+ content_tf = content_transform()
+ style_tf = style_transform(style_size, style_size)

 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 network.to(device)
+ content = content_tf(content_img.convert("RGB"))
 style = style_tf(style_img.convert("RGB"))
 style = style.to(device).unsqueeze(0)
 content = content.to(device).unsqueeze(0)
 with torch.no_grad():
+ output = network(content, style)
 output = output[0].cpu().squeeze()
+ output = (
+ output.mul(255)
+ .add_(0.5)
+ .clamp_(0, 255)
+ .permute(1, 2, 0)
+ .to("cpu", torch.uint8)
+ .numpy()
+ )
 return Image.fromarray(output)

+
+ # STYLE-FAST
+
+ style_transfer_model = tfhub.load(
+ "https://tfhub.dev/google/magenta/arbitrary-image-stylization-v1-256/2"
+ )
+
+
+ def StyleFAST(content_image: Image.Image, style_image: Image.Image) -> Image.Image:
+ """
+ This function applies a Fast image style transfer technique,
+ which uses a pretrained model from tensorhub.
+
+ Parameters:
+ content_image = the image with the content
+ style_image = the image with the style/pattern
+ Returns:
+ stylized_image = an image that is a combination of both
+ """
+ content_image = (
+ tf.convert_to_tensor(np.array(content_image), np.float32)[tf.newaxis, ...]
+ / 255.0
+ )
+ style_image = (
+ tf.convert_to_tensor(np.array(style_image), np.float32)[tf.newaxis, ...] / 255.0
+ )
 output = style_transfer_model(content_image, style_image)
 stylized_image = output[0]
 return Image.fromarray(np.uint8(stylized_image[0] * 255))

+
+ # STYLE PROJECTION
+
 stylepro_artistic = phub.Module(name="stylepro_artistic")


+ def styleProjection(
+ content_image: Image.Image, style_image: Image.Image, alpha: float = 1.0
+ ):
+ """
+ This function uses parameter free style transfer,
+ based on a model from paddlehub.
+ There is an optional weight parameter alpha, which
+ allows to control the balance between image and style.
+
+ Parameters:
+ content_image = the image with the content
+ style_image = the image with the style/pattern
+ alpha = weight for the image vs style.
+ This should be a float between 0 and 1.
+ Returns:
+ result = an image that is a combination of both
+ """
+ result = stylepro_artistic.style_transfer(
+ images=[
+ {
+ "content": np.array(content_image.convert("RGB"))[:, :, ::-1],
+ "styles": [np.array(style_image.convert("RGB"))[:, :, ::-1]],
+ }
+ ],
+ alpha=alpha,
+ )
+
+ return Image.fromarray(np.uint8(result[0]["data"])[:, :, ::-1]).convert("RGB")
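After this rewrite the module exposes three separate entry points — `StyleTransformer` (the StyTR transformer backend), `StyleFAST` (the TF-Hub arbitrary-stylization model) and `styleProjection` (the PaddleHub module, with a style weight `alpha`) — and loads all three backends at import time. A small usage sketch follows; it is illustrative only, not part of this commit, and the image paths are the example files shipped under `figures/` in this repo.

```python
# Illustrative sketch, not part of the commit. Importing the module is heavy:
# it loads the VGG/decoder weights, the TF-Hub model and the PaddleHub module.
from PIL import Image
from StyleTransfer.styleTransfer import StyleFAST, StyleTransformer, styleProjection

content = Image.open("figures/sofa_example1.jpg")
style = Image.open("figures/style_example1.jpg")

styled_fast = StyleFAST(content, style)                   # TF-Hub backend
styled_trans = StyleTransformer(content, style)           # transformer backend; the app feeds it 640x640 crops
styled_proj = styleProjection(content, style, alpha=0.8)  # PaddleHub backend, alpha between 0 and 1
styled_fast.save("styled_preview.jpg")                    # hypothetical output path
```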
app.py CHANGED
@@ -1,50 +1,71 @@
- from random import randint
- from typing import Tuple
-
- import gradio as gr
 import numpy as np
- from PIL import Image, ExifTags
-
 from Segmentation.segmentation import get_mask, replace_sofa
- from StyleTransfer.styleTransfer import create_styledSofa

 def fix_orient(img: Image.Image) -> Image.Image:
 for orientation in ExifTags.TAGS.keys():
- if ExifTags.TAGS[orientation]=='Orientation':
 break
- info=img.getexif()
- if (info):
 info = dict(info.items())
- orientation = info[orientation]
- if (orientation == 1) | (orientation == 2):
- img = img
- if (orientation == 3) | (orientation == 4):
- img = img.rotate(180,expand=True)
- if (orientation == 5) | (orientation == 6):
- img = img.rotate(270,expand=True)
- if (orientation == 7) | (orientation == 8):
- img = img.rotate(90,expand=True)
 return img


- def resize_sofa(img: Image.Image) -> Tuple[Image.Image, tuple]:
 """
- This function adds padding to make the original image square
- and 640by640. It also returns the original ratio of the image,
- such that it can be reverted later.
 Parameters:
 img = original image
 Return:
- im1 = squared image
- box = parameters to later crop the image to it original ratio
 """
 width, height = img.size
 idx = np.argmin([width, height])
 newsize = (640, 640)  # parameters from test script

 if idx == 0:
- img1 = Image.new(img.mode, (height, height), (255, 255, 255))
- img1.paste(img, ((height - width) // 2, 0))
 box = (
 newsize[0] * (1 - width / height) // 2,
 0,
@@ -52,22 +73,22 @@ def resize_sofa(img: Image.Image) -> Tuple[Image.Image, tuple]:
 newsize[1],
 )
 else:
- img1 = Image.new(img.mode, (width, width), (255, 255, 255))
- img1.paste(img, (0, (width - height) // 2))
 box = (
 0,
 newsize[1] * (1 - height / width) // 2,
 newsize[0],
 newsize[1] - newsize[1] * (1 - height / width) // 2,
 )
- im1 = img1.resize(newsize)
- return im1, box


 def resize_style(img: Image.Image) -> Image.Image:
 """
- This function generates a zoomed out version of the style
- image and resizes it to a 640by640 square.
 Parameters:
 img = image containing the style/pattern
 Return:
@@ -88,114 +109,249 @@ def resize_style(img: Image.Image) -> Image.Image:
 top = 0
 bottom = height
 newsize = (640, 640)  # parameters from test script
- im1 = img.crop((left, top, right, bottom))

 # Constructs a zoomed-out version
 copies = 8
 resize = (newsize[0] // copies, newsize[1] // copies)
- dst = Image.new("RGB", (resize[0] * copies, resize[1] * copies))
- im2 = im1.resize((resize))
 for row in range(copies):
- im2 = im2.transpose(Image.FLIP_LEFT_RIGHT)
 for column in range(copies):
- im2 = im2.transpose(Image.FLIP_TOP_BOTTOM)
- dst.paste(im2, (resize[0] * row, resize[1] * column))
- dst = dst.resize((newsize))
- return dst
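The crop box returned by `resize_sofa` above is what later undoes the white padding: the padded border is rescaled into the 640x640 frame, so `Image.crop(box)` restores the original aspect ratio. A worked example of that arithmetic, illustrative only and not part of this diff:

```python
# Box arithmetic for resize_sofa, assuming a hypothetical 1280x960 input
# (height is the shorter side, so the image is padded to a 1280x1280 square
# before being resized to 640x640).
w, h = 1280, 960
newsize = (640, 640)
box = (
    0,
    newsize[1] * (1 - h / w) // 2,               # 80.0 -> scaled-down top padding
    newsize[0],
    newsize[1] - newsize[1] * (1 - h / w) // 2,  # 560.0
)
print(box)  # (0, 80.0, 640, 560.0): cropping recovers the original 4:3 ratio
```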


- def style_sofa(
- Input_image: Image.Image, Style_image: Image.Image, Choice_of_algorithm: str
- ) -> Image.Image:
- """
- Styles (all) the sofas in the image to the given style.
- This function uses a transformer to combine the image with
- the desired style according to a generated mask of the sofas
- in the image.
- Input:
- input_img = image containing a sofa
- style_img = image containing a style
- choice = Style transfer algorithm to use
- Return:
- new_sofa = image containing the styled sofa
- """
- id = randint(0, 10)
- print("Starting job ", id)
- # preprocess input images to fit requirements of the segmentation model
- resized_img, box = resize_sofa(fix_orient(Input_image))
- resized_style = resize_style(fix_orient(Style_image))
- # resized_style.save('resized_style.jpg')
- # generate mask for image
- print("generating mask...")
- mask = get_mask(resized_img)
- # mask.save('mask.jpg')
- # Created a styled sofa
- print("Styling sofa...")
- styled_sofa = create_styledSofa(resized_img, resized_style, Choice_of_algorithm)
- # styled_sofa.save('styled_sofa.jpg')
- # postprocess the final image
- print("Replacing sofa...")
- new_sofa = replace_sofa(resized_img, mask, styled_sofa)
- new_sofa = new_sofa.crop(box)
- print("Finishing job", id)
- return new_sofa
-
-
- demo = gr.Interface(
- style_sofa,
- inputs=[
- gr.inputs.Image(type="pil"),
- gr.inputs.Image(type="pil"),
- gr.inputs.Radio(
- ["Style Transformer", "Style FAST", "Style Projection"],
- default="Style FAST",
- ),
- ],
- outputs="image",
- examples=[
- [
- "figures/sofa_example1.jpg",
- "figures/style_example1.jpg",
- "Style Transformer",
- ],
- [
- "figures/sofa_example3.jpg",
- "figures/style_example10.jpg",
- "Style FAST",
- ],
- [
- "figures/sofa_example2.jpg",
- "figures/style_example6.jpg",
- "Style Projection",
- ],
- ],
- title="πŸ›‹ Style your sofa πŸ›‹ ",
- description="Customize your sofa to your wildest dreams πŸ’­!\
- \nProvide a picture of your sofa, a desired pattern\
- and (optionally) choose one of the algorithms.\
- \nOr just pick one of the examples below. ⬇",
- theme="huggingface",
- enable_queue=True,
- article="**References**\n\n"
- "<a href='https://tianchi.aliyun.com/specials/promotion/alibaba-3d-future' \