sunshineatnoon commited on
Commit
827b81f
1 Parent(s): 733ce12
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. app.py +18 -18
  2. data/images/108073.jpg +0 -0
  3. data/{test_images → images}/130014.jpg +0 -0
  4. data/images/134008.jpg +0 -0
  5. data/images/25098.jpg +0 -0
  6. data/images/45077.jpg +0 -0
  7. data/images/corn.jpg +0 -0
  8. data/test_images/108004.jpg +0 -0
  9. data/test_images/130066.jpg +0 -0
  10. data/test_images/16068.jpg +0 -0
  11. data/test_images/208078.jpg +0 -0
  12. data/test_images/223060.jpg +0 -0
  13. data/test_images/388006.jpg +0 -0
  14. data/test_images/78098.jpg +0 -0
  15. models/week0417/__pycache__/loss.cpython-38.pyc +0 -0
  16. models/week0417/__pycache__/meanshift_utils.cpython-37.pyc +0 -0
  17. models/week0417/__pycache__/meanshift_utils.cpython-38.pyc +0 -0
  18. models/week0417/__pycache__/model.cpython-37.pyc +0 -0
  19. models/week0417/__pycache__/model.cpython-38.pyc +0 -0
  20. models/week0417/__pycache__/nnutils.cpython-38.pyc +0 -0
  21. models/week0417/__pycache__/taming_blocks.cpython-38.pyc +0 -0
  22. models/week0417/meanshift_utils.py +62 -0
  23. models/week0417/model.py +220 -43
  24. models/week0417/model_bk.py +204 -0
  25. swapae/models/__pycache__/__init__.cpython-38.pyc +0 -0
  26. swapae/models/__pycache__/base_model.cpython-38.pyc +0 -0
  27. swapae/models/networks/__pycache__/__init__.cpython-38.pyc +0 -0
  28. swapae/models/networks/__pycache__/base_network.cpython-38.pyc +0 -0
  29. swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc +0 -0
  30. swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc +0 -0
  31. swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc +0 -0
  32. swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc +0 -0
  33. swapae/util/__pycache__/__init__.cpython-38.pyc +0 -0
  34. swapae/util/__pycache__/html.cpython-38.pyc +0 -0
  35. swapae/util/__pycache__/iter_counter.cpython-38.pyc +0 -0
  36. swapae/util/__pycache__/metric_tracker.cpython-38.pyc +0 -0
  37. swapae/util/__pycache__/util.cpython-38.pyc +0 -0
  38. swapae/util/__pycache__/visualizer.cpython-38.pyc +0 -0
  39. tmp/0.png +0 -0
  40. tmp/1.png +0 -0
  41. tmp/2.png +0 -0
  42. tmp/3.png +0 -0
  43. tmp/4.png +0 -0
  44. tmp/5.png +0 -0
  45. tmp/6.png +0 -0
  46. tmp/7.png +0 -0
  47. tmp/8.png +0 -0
  48. tmp/9.png +0 -0
  49. tmp/tmp.png +0 -0
  50. weights/108004/exp_args.json +0 -63
app.py CHANGED
@@ -167,15 +167,15 @@ class Tester(TesterBase):
167
 
168
  sampled_code = tex_seg[:, tex_idx, :]
169
  rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
170
- sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
171
  H = tex_size // 8; W = tex_size // 8
172
  noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
173
  dec_input = torch.cat((sine_wave, noise), dim = 1)
174
 
175
- weight = self.model.ChannelWeight(rec_tex)
176
- weight = F.adaptive_avg_pool2d(weight, output_size = (1)).view(weight.shape[0], -1, 1, 1)
177
- weight = torch.sigmoid(weight)
178
- dec_input *= weight
179
 
180
  rep_rec = self.model.G(dec_input, rec_tex)
181
  rep_rec = (rep_rec + 1) / 2.0
@@ -256,6 +256,7 @@ class Tester(TesterBase):
256
  given_mask = torch.from_numpy(given_mask)
257
  H, W = given_mask.shape[0], given_mask.shape[1]
258
  given_mask = label2one_hot_torch(given_mask.view(1, 1, H, W), C = (given_mask.max()+1))
 
259
  mask_img_list = []
260
  for i in range(given_mask.shape[1]):
261
  part_img = self.to_pil(given_mask[0, i])
@@ -291,22 +292,26 @@ class Tester(TesterBase):
291
  #tex_size = (tex_size // 8) * 8
292
  with torch.no_grad():
293
  edited = self.model_forward_editing(self.data, self.slic, options=options, given_mask=given_mask, return_type = 'edited')
294
- col1, col2, col3, col4 = st.columns([1, 1, 4, 1])
295
  with col1:
296
- st.markdown("")
297
-
298
- with col2:
299
  st.markdown("Input image")
300
  img = F.interpolate(self.data, size = edited.shape[-2:], mode = 'bilinear', align_corners = False)
301
  st.image(self.to_pil((img + 1) / 2.0))
302
  print(img.shape, edited.shape)
303
 
 
 
 
 
 
 
 
 
 
304
  with col3:
305
  st.markdown("Synthesized texture image")
306
  st.image(self.to_pil(edited))
307
 
308
- with col4:
309
- st.markdown("")
310
  st.markdown('<p class="big-font">You can choose another image from the examplar images on the top and start again!</p>', unsafe_allow_html=True)
311
 
312
  def model_forward_editing(self, rgb_img, slic, epoch = 1000, test_time = False,
@@ -334,23 +339,17 @@ class Tester(TesterBase):
334
  tex_seg = poolfeat(conv_feats, softmax, avg = True)
335
  seg = label2one_hot_torch(torch.argmax(softmax, dim = 1).unsqueeze(1), C = softmax.shape[1])
336
 
337
- given_mask = F.interpolate(given_mask, size = (512, 512), mode = 'bilinear', align_corners = False)
338
  rec_tex = torch.zeros((1, tex_seg.shape[-1], 512, 512))
339
  for i in range(given_mask.shape[1]):
340
  label = options[i]
341
  code = tex_seg[0, label, :].view(1, -1, 1, 1).repeat(1, 1, 512, 512)
342
  rec_tex += code * given_mask[:, i:i+1]
343
  tex_size = 512
344
- sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
345
  H = tex_size // 8; W = tex_size // 8
346
  noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
347
  dec_input = torch.cat((sine_wave, noise), dim = 1)
348
 
349
- weight = self.model.ChannelWeight(rec_tex)
350
- weight = F.adaptive_avg_pool2d(weight, output_size = (1)).view(weight.shape[0], -1, 1, 1)
351
- weight = torch.sigmoid(weight)
352
- dec_input *= weight
353
-
354
  rep_rec = self.model.G(dec_input, rec_tex)
355
  rep_rec = (rep_rec + 1) / 2.0
356
  return rep_rec
@@ -359,6 +358,7 @@ class Tester(TesterBase):
359
  rgb_img = Image.open(data_path)
360
  crop_size = self.args.crop_size
361
  i = 40; j = 40; h = crop_size; w = crop_size
 
362
  rgb_img = TF.crop(rgb_img, i, j, h, w)
363
 
364
  # compute superpixel
167
 
168
  sampled_code = tex_seg[:, tex_idx, :]
169
  rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
170
+ sine_wave = self.model.get_sine_wave(rec_tex, 'rec')[:1]
171
  H = tex_size // 8; W = tex_size // 8
172
  noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
173
  dec_input = torch.cat((sine_wave, noise), dim = 1)
174
 
175
+ #weight = self.model.ChannelWeight(rec_tex)
176
+ #weight = F.adaptive_avg_pool2d(weight, output_size = (1)).view(weight.shape[0], -1, 1, 1)
177
+ #weight = torch.sigmoid(weight)
178
+ #dec_input *= weight
179
 
180
  rep_rec = self.model.G(dec_input, rec_tex)
181
  rep_rec = (rep_rec + 1) / 2.0
256
  given_mask = torch.from_numpy(given_mask)
257
  H, W = given_mask.shape[0], given_mask.shape[1]
258
  given_mask = label2one_hot_torch(given_mask.view(1, 1, H, W), C = (given_mask.max()+1))
259
+ given_mask = F.interpolate(given_mask, size = (512, 512), mode = 'bilinear', align_corners = False)
260
  mask_img_list = []
261
  for i in range(given_mask.shape[1]):
262
  part_img = self.to_pil(given_mask[0, i])
292
  #tex_size = (tex_size // 8) * 8
293
  with torch.no_grad():
294
  edited = self.model_forward_editing(self.data, self.slic, options=options, given_mask=given_mask, return_type = 'edited')
295
+ col1, col2, col3 = st.columns([1, 1, 1])
296
  with col1:
 
 
 
297
  st.markdown("Input image")
298
  img = F.interpolate(self.data, size = edited.shape[-2:], mode = 'bilinear', align_corners = False)
299
  st.image(self.to_pil((img + 1) / 2.0))
300
  print(img.shape, edited.shape)
301
 
302
+ with col2:
303
+ st.markdown("Given mask")
304
+ seg = Image.open(mask_path).convert("L")
305
+ seg = np.asarray(seg)
306
+ seg = torch.from_numpy(seg).view(1, 1, seg.shape[0], seg.shape[1])
307
+ color_vq = self.draw_color_seg(seg)
308
+ color_vq = F.interpolate(color_vq, size = (512, 512), mode = 'bilinear', align_corners = False)
309
+ st.image(self.to_pil(color_vq))
310
+
311
  with col3:
312
  st.markdown("Synthesized texture image")
313
  st.image(self.to_pil(edited))
314
 
 
 
315
  st.markdown('<p class="big-font">You can choose another image from the examplar images on the top and start again!</p>', unsafe_allow_html=True)
316
 
317
  def model_forward_editing(self, rgb_img, slic, epoch = 1000, test_time = False,
339
  tex_seg = poolfeat(conv_feats, softmax, avg = True)
340
  seg = label2one_hot_torch(torch.argmax(softmax, dim = 1).unsqueeze(1), C = softmax.shape[1])
341
 
 
342
  rec_tex = torch.zeros((1, tex_seg.shape[-1], 512, 512))
343
  for i in range(given_mask.shape[1]):
344
  label = options[i]
345
  code = tex_seg[0, label, :].view(1, -1, 1, 1).repeat(1, 1, 512, 512)
346
  rec_tex += code * given_mask[:, i:i+1]
347
  tex_size = 512
348
+ sine_wave = self.model.get_sine_wave(rec_tex, 'rec')[:1]
349
  H = tex_size // 8; W = tex_size // 8
350
  noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
351
  dec_input = torch.cat((sine_wave, noise), dim = 1)
352
 
 
 
 
 
 
353
  rep_rec = self.model.G(dec_input, rec_tex)
354
  rep_rec = (rep_rec + 1) / 2.0
355
  return rep_rec
358
  rgb_img = Image.open(data_path)
359
  crop_size = self.args.crop_size
360
  i = 40; j = 40; h = crop_size; w = crop_size
361
+ rgb_img = transforms.Resize(size=320)(rgb_img)
362
  rgb_img = TF.crop(rgb_img, i, j, h, w)
363
 
364
  # compute superpixel
data/images/108073.jpg DELETED
Binary file (76.4 kB)
data/{test_images → images}/130014.jpg RENAMED
File without changes
data/images/134008.jpg DELETED
Binary file (61.8 kB)
data/images/25098.jpg DELETED
Binary file (80.1 kB)
data/images/45077.jpg DELETED
Binary file (84.3 kB)
data/images/corn.jpg ADDED
data/test_images/108004.jpg DELETED
Binary file (107 kB)
data/test_images/130066.jpg DELETED
Binary file (62.9 kB)
data/test_images/16068.jpg DELETED
Binary file (79.2 kB)
data/test_images/208078.jpg DELETED
Binary file (81.3 kB)
data/test_images/223060.jpg DELETED
Binary file (82.5 kB)
data/test_images/388006.jpg DELETED
Binary file (76.2 kB)
data/test_images/78098.jpg DELETED
Binary file (103 kB)
models/week0417/__pycache__/loss.cpython-38.pyc ADDED
Binary file (14 kB). View file
models/week0417/__pycache__/meanshift_utils.cpython-37.pyc ADDED
Binary file (1.78 kB). View file
models/week0417/__pycache__/meanshift_utils.cpython-38.pyc ADDED
Binary file (1.79 kB). View file
models/week0417/__pycache__/model.cpython-37.pyc CHANGED
Binary files a/models/week0417/__pycache__/model.cpython-37.pyc and b/models/week0417/__pycache__/model.cpython-37.pyc differ
models/week0417/__pycache__/model.cpython-38.pyc CHANGED
Binary files a/models/week0417/__pycache__/model.cpython-38.pyc and b/models/week0417/__pycache__/model.cpython-38.pyc differ
models/week0417/__pycache__/nnutils.cpython-38.pyc CHANGED
Binary files a/models/week0417/__pycache__/nnutils.cpython-38.pyc and b/models/week0417/__pycache__/nnutils.cpython-38.pyc differ
models/week0417/__pycache__/taming_blocks.cpython-38.pyc CHANGED
Binary files a/models/week0417/__pycache__/taming_blocks.cpython-38.pyc and b/models/week0417/__pycache__/taming_blocks.cpython-38.pyc differ
models/week0417/meanshift_utils.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+
4
+ def pairwise_distances(x, y):
5
+ #Input: x is a Nxd matrix
6
+ # y is an optional Mxd matirx
7
+ #Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
8
+ # if y is not given then use 'y=x'.
9
+ #i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
10
+ x_norm = (x ** 2).sum(1).view(-1, 1)
11
+ y_t = torch.transpose(y, 0, 1)
12
+ y_norm = (y ** 2).sum(1).view(1, -1)
13
+ dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
14
+ return torch.clamp(dist, 0.0, np.inf)
15
+
16
+ def meanshift_cluster(pts, bandwidth, weights = None, meanshift_step = 15, step_size = 0.3):
17
+ """
18
+ meanshift written in pytorch
19
+ :param pts: input points
20
+ :param weights: weight per point during clustering
21
+ :return: clustered points
22
+ """
23
+ pts_steps = []
24
+ for i in range(meanshift_step):
25
+ Y = pairwise_distances(pts, pts)
26
+ K = torch.nn.functional.relu(bandwidth ** 2 - Y)
27
+ if weights is not None:
28
+ K = K * weights
29
+ P = torch.nn.functional.normalize(K, p=1, dim=0, eps=1e-10)
30
+ P = P.transpose(0, 1)
31
+ pts = step_size * (torch.matmul(P, pts) - pts) + pts
32
+ pts_steps.append(pts)
33
+ return pts_steps
34
+
35
+ def distance(a,b):
36
+ return torch.sqrt(((a-b)**2).sum())
37
+
38
+ def meanshift_assign(points, bandwidth):
39
+ cluster_ids = []
40
+ cluster_idx = 0
41
+ cluster_centers = []
42
+
43
+ for i, point in enumerate(points):
44
+ if(len(cluster_ids) == 0):
45
+ cluster_ids.append(cluster_idx)
46
+ cluster_centers.append(point)
47
+ cluster_idx += 1
48
+ else:
49
+ # assign to nearest cluster
50
+ #for j,center in enumerate(cluster_centers):
51
+ # dist = distance(point, center)
52
+ # if(dist < bandwidth):
53
+ # cluster_ids.append(j)
54
+ cdist = torch.cdist(point.unsqueeze(0), torch.stack(cluster_centers), p = 2)
55
+ nearest_idx = torch.argmin(cdist, dim = 1)
56
+ if cdist[0, nearest_idx] < bandwidth:
57
+ cluster_ids.append(nearest_idx)
58
+ else:
59
+ cluster_ids.append(cluster_idx)
60
+ cluster_centers.append(point)
61
+ cluster_idx += 1
62
+ return cluster_ids, cluster_centers
models/week0417/model.py CHANGED
@@ -7,21 +7,26 @@ import torchvision.transforms as transforms
7
  import torchvision.transforms.functional as TF
8
 
9
  from .taming_blocks import Encoder
 
10
  from .nnutils import SPADEResnetBlock, get_edges, initWave
11
 
12
  from libs.nnutils import poolfeat, upfeat
13
  from libs.utils import label2one_hot_torch
 
14
 
15
  from swapae.models.networks.stylegan2_layers import ConvLayer
16
  from torch_geometric.nn import GCNConv
17
  from torch_geometric.utils import softmax
18
- from .loss import styleLossMaskv3
 
 
 
19
 
20
  class GCN(nn.Module):
21
- def __init__(self, n_cluster, temperature = 1, add_self_loops = True, hidden_dim = 256):
22
  super().__init__()
23
- self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
24
- self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
25
  self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
26
  self.temperature = temperature
27
 
@@ -65,10 +70,7 @@ class GCN(nn.Module):
65
 
66
  # compute texture code w.r.t grouping
67
  pool_feat = poolfeat(conv_feat, s_, avg = True)
68
- # hard upsampling
69
- #hard_s_ = label2one_hot_torch(torch.argmax(s_, dim = 1).unsqueeze(1), C = s_.shape[1])
70
  feat = upfeat(pool_feat, s_)
71
- #feat = upfeat(pool_feat, hard_s_)
72
 
73
  prop_code.append(feat)
74
  sp_assign.append(pred_mask)
@@ -137,10 +139,7 @@ class AE(nn.Module):
137
  # encoder & decoder
138
  self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[],
139
  in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
140
- if args.dec_input_mode == 'sine_wave_noise':
141
- self.G = SPADEGenerator(args.spatial_code_dim * 2, args.hidden_dim)
142
- else:
143
- self.G = SPADEGenerator(args.spatial_code_dim, args.hidden_dim)
144
 
145
  self.add_module(
146
  "ToTexCode",
@@ -150,40 +149,61 @@ class AE(nn.Module):
150
  ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False)
151
  )
152
  )
153
- self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, add_self_loops = (args.add_self_loops == 1), hidden_dim = args.hidden_dim)
154
 
155
  self.add_gcn_epoch = args.add_gcn_epoch
156
  self.add_clustering_epoch = args.add_clustering_epoch
157
  self.add_texture_epoch = args.add_texture_epoch
158
 
159
  self.patch_size = args.patch_size
 
160
  self.sine_wave_dim = args.spatial_code_dim
 
 
161
 
162
  # inpainting network
163
- self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
164
- self.dec_input_mode = args.dec_input_mode
165
- self.style_loss = styleLossMaskv3(device = args.device)
166
-
167
- if args.sine_weight:
168
- if args.dec_input_mode == 'sine_wave_noise':
169
- self.add_module(
170
- "ChannelWeight",
171
- nn.Sequential(
172
- ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
173
- ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
174
- ConvLayer(args.hidden_dim//4, args.spatial_code_dim*2, kernel_size=1, activate=False, bias=False, downsample=True)))
175
- else:
176
- self.add_module(
177
- "ChannelWeight",
178
- nn.Sequential(
179
- ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
180
- ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
181
- ConvLayer(args.hidden_dim//4, args.spatial_code_dim, kernel_size=1, activate=False, bias=False, downsample=True)))
182
-
183
- def get_sine_wave(self, GL, offset_mode = 'random'):
184
- img_size = GL.shape[-1] // 8
185
- GL = F.interpolate(GL, size = (img_size, img_size), mode = 'nearest')
186
- xv, yv = np.meshgrid(np.arange(img_size), np.arange(img_size),indexing='ij')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
187
  c = torch.FloatTensor(np.concatenate([xv[np.newaxis], yv[np.newaxis]], 0)[np.newaxis])
188
  c = c.to(GL.device)
189
  # c: 1, 2, 28, 28
@@ -192,13 +212,170 @@ class AE(nn.Module):
192
  period = self.learnedWN(GL)
193
  # period: 1, 64, 28, 28
194
  raw = period * c
195
- if offset_mode == 'random':
196
- offset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
197
- offset = offset.repeat(1, 1, img_size, img_size)
198
- wave = torch.sin(raw[:, ::2] + raw[:, 1::2] + offset)
199
- elif offset_mode == 'rec':
200
- wave = torch.sin(raw[:, ::2] + raw[:, 1::2])
 
 
 
 
 
201
  return wave
202
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
203
  def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None):
204
- return
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  import torchvision.transforms.functional as TF
8
 
9
  from .taming_blocks import Encoder
10
+ from .loss import styleLossMaskv3
11
  from .nnutils import SPADEResnetBlock, get_edges, initWave
12
 
13
  from libs.nnutils import poolfeat, upfeat
14
  from libs.utils import label2one_hot_torch
15
+ from .meanshift_utils import meanshift_cluster, meanshift_assign
16
 
17
  from swapae.models.networks.stylegan2_layers import ConvLayer
18
  from torch_geometric.nn import GCNConv
19
  from torch_geometric.utils import softmax
20
+
21
+ import sys
22
+ sys.path.append('models/third_party/cython')
23
+ from connectivity import enforce_connectivity
24
 
25
  class GCN(nn.Module):
26
+ def __init__(self, n_cluster, temperature = 1, hidden_dim = 256):
27
  super().__init__()
28
+ self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = True)
29
+ self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = True)
30
  self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
31
  self.temperature = temperature
32
 
70
 
71
  # compute texture code w.r.t grouping
72
  pool_feat = poolfeat(conv_feat, s_, avg = True)
 
 
73
  feat = upfeat(pool_feat, s_)
 
74
 
75
  prop_code.append(feat)
76
  sp_assign.append(pred_mask)
139
  # encoder & decoder
140
  self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[],
141
  in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
142
+ self.G = SPADEGenerator(args.spatial_code_dim + 32, args.hidden_dim)
 
 
 
143
 
144
  self.add_module(
145
  "ToTexCode",
149
  ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False)
150
  )
151
  )
152
+ self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, hidden_dim = args.hidden_dim)
153
 
154
  self.add_gcn_epoch = args.add_gcn_epoch
155
  self.add_clustering_epoch = args.add_clustering_epoch
156
  self.add_texture_epoch = args.add_texture_epoch
157
 
158
  self.patch_size = args.patch_size
159
+ self.style_loss = styleLossMaskv3(device = args.device)
160
  self.sine_wave_dim = args.spatial_code_dim
161
+ self.noise_dim = 32
162
+ self.spatial_code_dim = args.spatial_code_dim
163
 
164
  # inpainting network
165
+ if args.spatial_code_dim > 0:
166
+ self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
167
+
168
+ self.add_module(
169
+ "Amplitude",
170
+ nn.Sequential(
171
+ nn.Conv2d(args.hidden_dim, args.hidden_dim//2, 1, 1, 0),
172
+ nn.Conv2d(args.hidden_dim//2, args.hidden_dim//4, 1, 1, 0),
173
+ nn.Conv2d(args.hidden_dim//4, args.spatial_code_dim, 1, 1, 0)
174
+ )
175
+ )
176
+
177
+ self.bandwidth = 3.0
178
+
179
+ def sample_patch_from_mask(self, mask, patch_num = 10, patch_size = 64):
180
+ """
181
+ - Sample `patch_num` patches of size `patch_size*patch_size` w.r.t given mask
182
+ """
183
+ nonzeros = torch.nonzero(mask.view(-1)).squeeze()
184
+ n = len(nonzeros)
185
+ xys = []
186
+ imgH, imgW = mask.shape
187
+ half_patch = patch_size // 2
188
+ iter_num = 0
189
+ while len(xys) < patch_num:
190
+ id = (torch.ones(n)*1.0/n).multinomial(num_samples=1, replacement=False)
191
+ rx = nonzeros[id] // imgW
192
+ ry = nonzeros[id] % imgW
193
+ top = max(0, rx - half_patch)
194
+ bot = min(imgH, rx + half_patch)
195
+ left = max(0, ry - half_patch)
196
+ right = min(imgW, ry + half_patch)
197
+ patch_mask = mask[top:bot, left:right]
198
+ if torch.sum(patch_mask) / (patch_size ** 2) > 0.5 or iter_num > 20:
199
+ xys.append([top, bot, left, right])
200
+ iter_num += 1
201
+ return xys
202
+
203
+ def get_sine_wave(self, GL, offset_mode = 'rec'):
204
+ imgH, imgW = GL.shape[-2]//8, GL.shape[-1] // 8
205
+ GL = F.interpolate(GL, size = (imgH, imgW), mode = 'nearest')
206
+ xv, yv = np.meshgrid(np.arange(imgH), np.arange(imgW),indexing='ij')
207
  c = torch.FloatTensor(np.concatenate([xv[np.newaxis], yv[np.newaxis]], 0)[np.newaxis])
208
  c = c.to(GL.device)
209
  # c: 1, 2, 28, 28
212
  period = self.learnedWN(GL)
213
  # period: 1, 64, 28, 28
214
  raw = period * c
215
+
216
+ # random offset
217
+ roffset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
218
+ roffset = roffset.repeat(1, 1, imgH, imgW)
219
+ rwave = torch.sin(raw[:, ::2] + raw[:, 1::2] + roffset)
220
+
221
+ # zero offset
222
+ zwave = torch.sin(raw[:, ::2] + raw[:, 1::2])
223
+ A = self.Amplitude(GL)
224
+ A = torch.sigmoid(A)
225
+ wave = torch.cat((zwave, rwave)) * A.repeat(2, 1, 1, 1)
226
  return wave
227
 
228
+ def syn_tex(self, tex_code, mask, imgH, imgW, offset_mode = 'rec', tex_idx = None):
229
+ # synthesize all textures
230
+ # spatial: B x 256 x 14 x 14
231
+ # tex_code: B x N x 256
232
+ B, N, _ = tex_code.shape
233
+ H = imgH // 8
234
+ W = imgW // 8
235
+
236
+ # randomly sample a texture and synthesize it
237
+ # throw away small texture segments
238
+ areas = torch.sum(mask, dim=(2, 3))
239
+ valid_idxs = torch.nonzero(areas[0] / (imgH * imgW) > 0.01).squeeze(-1)
240
+ if tex_idx is None or tex_idx >= tex_code.shape[1]:
241
+ tex_idx = valid_idxs[torch.multinomial(areas[0, valid_idxs], 1).squeeze()]
242
+ else:
243
+ sorted_list = torch.argsort(areas, dim = 1, descending = True)
244
+ tex_idx = sorted_list[0, tex_idx]
245
+ sampled_code = tex_code[:, tex_idx, :]
246
+ rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
247
+
248
+ # Decoder: Spatial & Texture code -> Image
249
+ if self.noise_dim == 0:
250
+ dec_input = self.get_sine_wave(rec_tex, offset_mode)
251
+ elif self.spatial_code_dim == 0:
252
+ dec_input = torch.randn(rec_tex.shape[0], self.noise_dim, H, W).to(tex_code.device)
253
+ else:
254
+ sine_wave = self.get_sine_wave(rec_tex, offset_mode)
255
+ noise = torch.randn(sine_wave.shape[0], self.noise_dim, H, W).to(tex_code.device)
256
+ dec_input = torch.cat((sine_wave, noise), dim = 1)
257
+
258
+ tex_syn = self.G(dec_input, rec_tex.repeat(dec_input.shape[0], 1, 1, 1))
259
+
260
+ return tex_syn, tex_idx
261
+
262
+ def sample_tex_patches(self, tex_idx, rgb_img, rep_rec, mask, patch_num = 10):
263
+ patches = []
264
+ masks = []
265
+ patch_masks = []
266
+ # sample patches from input image and reconstruction
267
+ for i in range(rgb_img.shape[0]):
268
+ # WARNING: : This only works for batch_size = 1 for now
269
+ maski = mask[i, tex_idx]
270
+ masks.append(maski.unsqueeze(0))
271
+ xys = self.sample_patch_from_mask(maski, patch_num = patch_num, patch_size = self.patch_size)
272
+ # sample 10 patches from input image & reconstruction w.r.t group mask
273
+ for k in range(patch_num):
274
+ top, bot, left, right = xys[k]
275
+ patch_ = rgb_img[i, :, top:bot, left:right]
276
+ patch_mask_ = maski[top:bot, left:right]
277
+
278
+ # In case the patch is on the boundary and smaller than patch_size
279
+ # We put the patch at some random place of a black image
280
+ h, w = patch_.shape[-2:]
281
+ x = 0; y = 0
282
+ if h < self.patch_size:
283
+ x = np.random.randint(0, self.patch_size - h)
284
+ if w < self.patch_size:
285
+ y = np.random.randint(0, self.patch_size - w)
286
+ patch = torch.zeros(1, 3, self.patch_size, self.patch_size).to(patch_.device)
287
+ patch_mask = torch.zeros(1, 1, self.patch_size, self.patch_size).to(patch_.device)
288
+ patch[:, :, x:x+h, y:y+w] = patch_
289
+ patch_mask[:, :, x:x+h, y:y+w] = patch_mask_
290
+ patches.append(patch)
291
+ patch_masks.append(patch_mask)
292
+ patches = torch.cat(patches)
293
+ masks = torch.stack(masks)
294
+ patch_masks = torch.cat(patch_masks)
295
+
296
+ # sample patches from synthesized texture
297
+ tex_patch_size = self.patch_size
298
+ rep_patches = []
299
+ for k in range(patch_num):
300
+ i, j, h, w = transforms.RandomCrop.get_params(rep_rec, output_size=(tex_patch_size, tex_patch_size))
301
+ rep_rec_patch = TF.crop(rep_rec, i, j, h, w)
302
+ rep_patches.append(rep_rec_patch)
303
+ rep_patches = torch.stack(rep_patches, dim = 1)
304
+ rep_patches = rep_patches.view(-1, 3, tex_patch_size, tex_patch_size)
305
+ return masks, patch_masks, patches, rep_patches
306
+
307
  def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None):
308
+ #self.patch_size = np.random.randint(64, 160)
309
+ B, _, imgH, imgW = rgb_img.shape
310
+ outputs = {}
311
+ rec_feat_list = []
312
+ seg_map = [torch.argmax(slic.cpu(), dim = 1)]
313
+
314
+ # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
315
+ conv_feat, layer_feats = self.enc(rgb_img)
316
+ B, C, H, W = conv_feat.shape
317
+
318
+ # Texture code for each superpixel
319
+ tex_code = self.ToTexCode(conv_feat)
320
+
321
+ code = F.interpolate(tex_code, size = (imgH, imgW), mode = 'bilinear', align_corners = False)
322
+ pool_code = poolfeat(code, slic, avg = True)
323
+
324
+ if epoch >= self.add_gcn_epoch:
325
+ prop_code, sp_assign, conv_feats = self.gcn(pool_code, slic, (self.add_clustering_epoch <= epoch))
326
+ softmax = F.softmax(sp_assign * self.gcn.temperature, dim = 1)
327
+ rec_feat_list.append(prop_code)
328
+ seg_map = [torch.argmax(sp_assign.cpu(), dim = 1)]
329
+ else:
330
+ rec_code = upfeat(pool_code, slic)
331
+ rec_feat_list.append(rec_code)
332
+ softmax = slic
333
+
334
+ # Texture synthesis
335
+ if epoch >= self.add_texture_epoch:
336
+ sp_feat = poolfeat(conv_feats, slic, avg = True).squeeze(0)
337
+ pts = meanshift_cluster(sp_feat, self.bandwidth, meanshift_step = 15)[-1]
338
+ with torch.no_grad():
339
+ sp_assign, _ = meanshift_assign(pts, self.bandwidth)
340
+ sp_assign = torch.tensor(sp_assign).unsqueeze(-1).to(slic.device).float()
341
+ sp_assign = upfeat(sp_assign, slic)
342
+ seg = label2one_hot_torch(sp_assign, C = sp_assign.max().long() + 1)
343
+ seg_map = [torch.argmax(seg.cpu(), dim = 1)]
344
+
345
+ # texture code for each connected group
346
+ tex_seg = poolfeat(conv_feats, seg, avg = True)
347
+ if test:
348
+ rep_rec, tex_idx = self.syn_tex(tex_seg, seg, 564, 564, tex_idx = tex_idx)
349
+ #rep_rec, tex_idx = self.syn_tex(tex_seg, seg, 1024, 1024, tex_idx = tex_idx)
350
+ else:
351
+ rep_rec, tex_idx = self.syn_tex(tex_seg, seg, imgH, imgW, tex_idx = tex_idx)
352
+ rep_rec = (rep_rec + 1) / 2.0
353
+ rgb_img = (rgb_img + 1) / 2.0
354
+
355
+ # sample patches from input image, reconstruction & synthesized texture
356
+ # zero offset
357
+ zmasks, zpatch_masks, zpatches, zrep_patches = self.sample_tex_patches(tex_idx, rgb_img, rep_rec[:1], seg)
358
+ # random offset
359
+ rmasks, rpatch_masks, rpatches, rrep_patches = self.sample_tex_patches(tex_idx, rgb_img, rep_rec[1:], seg)
360
+ masks = torch.cat((zmasks, rmasks))
361
+ patch_masks = torch.cat((zpatch_masks, rpatch_masks))
362
+ patches = torch.cat((zpatches, rpatches))
363
+ rep_patches = torch.cat((zrep_patches, rrep_patches))
364
+
365
+ # Gram matrix matching loss between:
366
+ # - patches from synthesized texture v.s. patches from input image
367
+ # - patches from reconstruction v.s. patches from input image
368
+ outputs['style_loss'] = self.style_loss.forward_patch_img(rep_patches, rgb_img.repeat(2, 1, 1, 1), masks)
369
+
370
+ outputs['rep_rec'] = rep_rec
371
+ outputs['masks'] = masks
372
+ outputs['patches'] = patches.view(-1, 3, self.patch_size, self.patch_size)
373
+ outputs['patch_masks'] = patch_masks
374
+ outputs['rep_patches'] = rep_patches * patch_masks + patches * (1 - patch_masks)
375
+
376
+ outputs['gt'] = rgb_img
377
+ bp_tex = rep_rec[:1, :, :imgH, :imgW] * masks[:1] + rgb_img * (1 - masks[:1])
378
+ outputs['rec'] = bp_tex
379
+
380
+ outputs['HA'] = torch.cat(seg_map)
381
+ return outputs
models/week0417/model_bk.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision.transforms as transforms
7
+ import torchvision.transforms.functional as TF
8
+
9
+ from .taming_blocks import Encoder
10
+ from .nnutils import SPADEResnetBlock, get_edges, initWave
11
+
12
+ from libs.nnutils import poolfeat, upfeat
13
+ from libs.utils import label2one_hot_torch
14
+
15
+ from swapae.models.networks.stylegan2_layers import ConvLayer
16
+ from torch_geometric.nn import GCNConv
17
+ from torch_geometric.utils import softmax
18
+ from .loss import styleLossMaskv3
19
+
20
+ class GCN(nn.Module):
21
+ def __init__(self, n_cluster, temperature = 1, add_self_loops = True, hidden_dim = 256):
22
+ super().__init__()
23
+ self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
24
+ self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
25
+ self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
26
+ self.temperature = temperature
27
+
28
+ def compute_edge_score_softmax(self, raw_edge_score, edge_index, num_nodes):
29
+ return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)
30
+
31
+ def compute_edge_weight(self, node_feature, edge_index):
32
+ src_feat = torch.gather(node_feature, 0, edge_index[0].unsqueeze(1).repeat(1, node_feature.shape[1]))
33
+ tgt_feat = torch.gather(node_feature, 0, edge_index[1].unsqueeze(1).repeat(1, node_feature.shape[1]))
34
+ raw_edge_weight = nn.CosineSimilarity(dim=1, eps=1e-6)(src_feat, tgt_feat)
35
+ edge_weight = self.compute_edge_score_softmax(raw_edge_weight, edge_index, node_feature.shape[0])
36
+ return raw_edge_weight.squeeze(), edge_weight.squeeze()
37
+
38
+ def forward(self, sp_code, slic, clustering = False):
39
+ edges, aff = get_edges(torch.argmax(slic, dim = 1).unsqueeze(1), sp_code.shape[1])
40
+ prop_code = []
41
+ sp_assign = []
42
+ edge_weights = []
43
+ conv_feats = []
44
+ for i in range(sp_code.shape[0]):
45
+ # compute edge weight
46
+ edge_index = edges[i]
47
+ raw_edge_weight, edge_weight = self.compute_edge_weight(sp_code[i], edge_index)
48
+ feat = self.gcnconv1(sp_code[i], edge_index, edge_weight = edge_weight)
49
+ raw_edge_weight, edge_weight = self.compute_edge_weight(feat, edge_index)
50
+ edge_weights.append(raw_edge_weight)
51
+ feat = F.leaky_relu(feat, 0.2)
52
+ feat = self.gcnconv2(feat, edge_index, edge_weight = edge_weight)
53
+
54
+ # maybe clustering
55
+ conv_feat = upfeat(feat, slic[i:i+1])
56
+ conv_feats.append(conv_feat)
57
+ if not clustering:
58
+ feat = conv_feat
59
+ pred_mask = slic[i:i+1]
60
+ else:
61
+ pred_mask = self.pool1(conv_feat)
62
+ # enforce pixels belong to the same superpixel to have same grouping label
63
+ pred_mask = upfeat(poolfeat(pred_mask, slic[i:i+1]), slic[i:i+1])
64
+ s_ = F.softmax(pred_mask * self.temperature, dim = 1)
65
+
66
+ # compute texture code w.r.t grouping
67
+ pool_feat = poolfeat(conv_feat, s_, avg = True)
68
+ # hard upsampling
69
+ #hard_s_ = label2one_hot_torch(torch.argmax(s_, dim = 1).unsqueeze(1), C = s_.shape[1])
70
+ feat = upfeat(pool_feat, s_)
71
+ #feat = upfeat(pool_feat, hard_s_)
72
+
73
+ prop_code.append(feat)
74
+ sp_assign.append(pred_mask)
75
+ prop_code = torch.cat(prop_code)
76
+ conv_feats = torch.cat(conv_feats)
77
+ return prop_code, torch.cat(sp_assign), conv_feats
78
+
79
+ class SPADEGenerator(nn.Module):
80
+ def __init__(self, in_dim, hidden_dim):
81
+ super().__init__()
82
+ nf = hidden_dim // 16
83
+
84
+ self.head_0 = SPADEResnetBlock(in_dim, 16 * nf)
85
+
86
+ self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf)
87
+ self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf)
88
+
89
+ self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf)
90
+ self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
91
+ self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf)
92
+ self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf)
93
+
94
+ final_nc = nf
95
+
96
+ self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)
97
+
98
+ self.up = nn.Upsample(scale_factor=2)
99
+
100
+ def forward(self, sine_wave, texon):
101
+
102
+ x = self.head_0(sine_wave, texon)
103
+
104
+ x = self.up(x)
105
+ x = self.G_middle_0(x, texon)
106
+ x = self.G_middle_1(x, texon)
107
+
108
+ x = self.up(x)
109
+ x = self.up_0(x, texon)
110
+ x = self.up(x)
111
+ x = self.up_1(x, texon)
112
+ #x = self.up(x)
113
+ x = self.up_2(x, texon)
114
+ #x = self.up(x)
115
+ x = self.up_3(x, texon)
116
+
117
+ x = self.conv_img(F.leaky_relu(x, 2e-1))
118
+ return x
119
+
120
+ class Waver(nn.Module):
121
+ def __init__(self, tex_code_dim, zPeriodic):
122
+ super(Waver, self).__init__()
123
+ K = tex_code_dim
124
+ layers = [nn.Conv2d(tex_code_dim, K, 1)]
125
+ layers += [nn.ReLU(True)]
126
+ layers += [nn.Conv2d(K, 2 * zPeriodic, 1)]
127
+ self.learnedWN = nn.Sequential(*layers)
128
+ self.waveNumbers = initWave(zPeriodic)
129
+
130
+ def forward(self, GLZ=None):
131
+ return (self.waveNumbers.to(GLZ.device) + self.learnedWN(GLZ))
132
+
133
+ class AE(nn.Module):
134
+ def __init__(self, args, **ignore_kwargs):
135
+ super(AE, self).__init__()
136
+
137
+ # encoder & decoder
138
+ self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[],
139
+ in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
140
+ if args.dec_input_mode == 'sine_wave_noise':
141
+ self.G = SPADEGenerator(args.spatial_code_dim * 2, args.hidden_dim)
142
+ else:
143
+ self.G = SPADEGenerator(args.spatial_code_dim, args.hidden_dim)
144
+
145
+ self.add_module(
146
+ "ToTexCode",
147
+ nn.Sequential(
148
+ ConvLayer(args.hidden_dim, args.hidden_dim, kernel_size=3, activate=True, bias=True),
149
+ ConvLayer(args.hidden_dim, args.tex_code_dim, kernel_size=3, activate=True, bias=True),
150
+ ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False)
151
+ )
152
+ )
153
+ self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, add_self_loops = (args.add_self_loops == 1), hidden_dim = args.hidden_dim)
154
+
155
+ self.add_gcn_epoch = args.add_gcn_epoch
156
+ self.add_clustering_epoch = args.add_clustering_epoch
157
+ self.add_texture_epoch = args.add_texture_epoch
158
+
159
+ self.patch_size = args.patch_size
160
+ self.sine_wave_dim = args.spatial_code_dim
161
+
162
+ # inpainting network
163
+ self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
164
+ self.dec_input_mode = args.dec_input_mode
165
+ self.style_loss = styleLossMaskv3(device = args.device)
166
+
167
+ if args.sine_weight:
168
+ if args.dec_input_mode == 'sine_wave_noise':
169
+ self.add_module(
170
+ "ChannelWeight",
171
+ nn.Sequential(
172
+ ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
173
+ ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
174
+ ConvLayer(args.hidden_dim//4, args.spatial_code_dim*2, kernel_size=1, activate=False, bias=False, downsample=True)))
175
+ else:
176
+ self.add_module(
177
+ "ChannelWeight",
178
+ nn.Sequential(
179
+ ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
180
+ ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
181
+ ConvLayer(args.hidden_dim//4, args.spatial_code_dim, kernel_size=1, activate=False, bias=False, downsample=True)))
182
+
183
+ def get_sine_wave(self, GL, offset_mode = 'random'):
184
+ img_size = GL.shape[-1] // 8
185
+ GL = F.interpolate(GL, size = (img_size, img_size), mode = 'nearest')
186
+ xv, yv = np.meshgrid(np.arange(img_size), np.arange(img_size),indexing='ij')
187
+ c = torch.FloatTensor(np.concatenate([xv[np.newaxis], yv[np.newaxis]], 0)[np.newaxis])
188
+ c = c.to(GL.device)
189
+ # c: 1, 2, 28, 28
190
+ c = c.repeat(GL.shape[0], self.sine_wave_dim, 1, 1)
191
+ # c: 1, 64, 28, 28
192
+ period = self.learnedWN(GL)
193
+ # period: 1, 64, 28, 28
194
+ raw = period * c
195
+ if offset_mode == 'random':
196
+ offset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
197
+ offset = offset.repeat(1, 1, img_size, img_size)
198
+ wave = torch.sin(raw[:, ::2] + raw[:, 1::2] + offset)
199
+ elif offset_mode == 'rec':
200
+ wave = torch.sin(raw[:, ::2] + raw[:, 1::2])
201
+ return wave
202
+
203
+ def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None):
204
+ return
swapae/models/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/swapae/models/__pycache__/__init__.cpython-38.pyc and b/swapae/models/__pycache__/__init__.cpython-38.pyc differ
swapae/models/__pycache__/base_model.cpython-38.pyc CHANGED
Binary files a/swapae/models/__pycache__/base_model.cpython-38.pyc and b/swapae/models/__pycache__/base_model.cpython-38.pyc differ
swapae/models/networks/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/swapae/models/networks/__pycache__/__init__.cpython-38.pyc and b/swapae/models/networks/__pycache__/__init__.cpython-38.pyc differ
swapae/models/networks/__pycache__/base_network.cpython-38.pyc CHANGED
Binary files a/swapae/models/networks/__pycache__/base_network.cpython-38.pyc and b/swapae/models/networks/__pycache__/base_network.cpython-38.pyc differ
swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc CHANGED
Binary files a/swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc and b/swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc differ
swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc and b/swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc differ
swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc CHANGED
Binary files a/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc and b/swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc differ
swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc CHANGED
Binary files a/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc and b/swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc differ
swapae/util/__pycache__/__init__.cpython-38.pyc CHANGED
Binary files a/swapae/util/__pycache__/__init__.cpython-38.pyc and b/swapae/util/__pycache__/__init__.cpython-38.pyc differ
swapae/util/__pycache__/html.cpython-38.pyc CHANGED
Binary files a/swapae/util/__pycache__/html.cpython-38.pyc and b/swapae/util/__pycache__/html.cpython-38.pyc differ
swapae/util/__pycache__/iter_counter.cpython-38.pyc CHANGED
Binary files a/swapae/util/__pycache__/iter_counter.cpython-38.pyc and b/swapae/util/__pycache__/iter_counter.cpython-38.pyc differ
swapae/util/__pycache__/metric_tracker.cpython-38.pyc CHANGED
Binary files a/swapae/util/__pycache__/metric_tracker.cpython-38.pyc and b/swapae/util/__pycache__/metric_tracker.cpython-38.pyc differ
swapae/util/__pycache__/util.cpython-38.pyc CHANGED
Binary files a/swapae/util/__pycache__/util.cpython-38.pyc and b/swapae/util/__pycache__/util.cpython-38.pyc differ
swapae/util/__pycache__/visualizer.cpython-38.pyc CHANGED
Binary files a/swapae/util/__pycache__/visualizer.cpython-38.pyc and b/swapae/util/__pycache__/visualizer.cpython-38.pyc differ
tmp/0.png ADDED
tmp/1.png ADDED
tmp/2.png ADDED
tmp/3.png ADDED
tmp/4.png ADDED
tmp/5.png ADDED
tmp/6.png ADDED
tmp/7.png ADDED
tmp/8.png ADDED
tmp/9.png ADDED
tmp/tmp.png ADDED
weights/108004/exp_args.json DELETED
@@ -1,63 +0,0 @@
1
- {
2
- "data_path": "/home/xtli/DATA/BSR_processed/train",
3
- "img_path": "data/test_images/108004.jpg",
4
- "test_path": null,
5
- "crop_size": 224,
6
- "scale_size": null,
7
- "batch_size": 1,
8
- "workers": 4,
9
- "pretrained_path": "/home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth",
10
- "hidden_dim": 256,
11
- "spatial_code_dim": 32,
12
- "tex_code_dim": 256,
13
- "exp_name": "04-18/108004",
14
- "project_name": "ssn_transformer",
15
- "nepochs": 20,
16
- "lr": 5e-05,
17
- "momentum": 0.5,
18
- "beta": 0.999,
19
- "lr_decay_freq": 3000,
20
- "save_freq": 1000,
21
- "save_freq_iter": 2000000000000,
22
- "log_freq": 10,
23
- "display_freq": 100,
24
- "use_wandb": 0,
25
- "work_dir": "/home/xtli/WORKDIR",
26
- "out_dir": "/home/xtli/WORKDIR/04-18/108004",
27
- "local_rank": 0,
28
- "dataset": "dataset",
29
- "config_file": "models/week0417/json/single_scale_grouping_ft.json",
30
- "lambda_L1": 1,
31
- "lambda_Perceptual": 1.0,
32
- "lambda_PatchGAN": 1.0,
33
- "lambda_GAN": 1,
34
- "add_gan_epoch": 0,
35
- "lambda_kld_loss": 1e-06,
36
- "lambda_style_loss": 1.0,
37
- "lambda_feat": 10.0,
38
- "use_slic": true,
39
- "patch_size": 64,
40
- "netPatchD_scale_capacity": 4.0,
41
- "netPatchD_max_nc": 384,
42
- "netPatchD": "StyleGAN2",
43
- "use_antialias": true,
44
- "patch_use_aggregation": false,
45
- "lambda_R1": 1.0,
46
- "lambda_ffl_loss": 1.0,
47
- "lambda_patch_R1": 1.0,
48
- "R1_once_every": 16,
49
- "add_self_loops": 1,
50
- "test_time": false,
51
- "sp_num": 196,
52
- "label_path": "/home/xtli/DATA/BSR/BSDS500/data/groundTruth",
53
- "model_name": "model",
54
- "num_D": 2,
55
- "n_layers_D": 3,
56
- "n_cluster": 10,
57
- "temperature": 23,
58
- "add_gcn_epoch": 0,
59
- "add_clustering_epoch": 0,
60
- "add_texture_epoch": 0,
61
- "dec_input_mode": "sine_wave_noise",
62
- "sine_weight": true
63
- }