Commit 827b81f ("new_model")
sunshineatnoon committed
1 parent: 733ce12

This view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.

Files changed:
- app.py +18 -18
- data/images/108073.jpg +0 -0
- data/{test_images → images}/130014.jpg +0 -0
- data/images/134008.jpg +0 -0
- data/images/25098.jpg +0 -0
- data/images/45077.jpg +0 -0
- data/images/corn.jpg +0 -0
- data/test_images/108004.jpg +0 -0
- data/test_images/130066.jpg +0 -0
- data/test_images/16068.jpg +0 -0
- data/test_images/208078.jpg +0 -0
- data/test_images/223060.jpg +0 -0
- data/test_images/388006.jpg +0 -0
- data/test_images/78098.jpg +0 -0
- models/week0417/__pycache__/loss.cpython-38.pyc +0 -0
- models/week0417/__pycache__/meanshift_utils.cpython-37.pyc +0 -0
- models/week0417/__pycache__/meanshift_utils.cpython-38.pyc +0 -0
- models/week0417/__pycache__/model.cpython-37.pyc +0 -0
- models/week0417/__pycache__/model.cpython-38.pyc +0 -0
- models/week0417/__pycache__/nnutils.cpython-38.pyc +0 -0
- models/week0417/__pycache__/taming_blocks.cpython-38.pyc +0 -0
- models/week0417/meanshift_utils.py +62 -0
- models/week0417/model.py +220 -43
- models/week0417/model_bk.py +204 -0
- swapae/models/__pycache__/__init__.cpython-38.pyc +0 -0
- swapae/models/__pycache__/base_model.cpython-38.pyc +0 -0
- swapae/models/networks/__pycache__/__init__.cpython-38.pyc +0 -0
- swapae/models/networks/__pycache__/base_network.cpython-38.pyc +0 -0
- swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc +0 -0
- swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc +0 -0
- swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc +0 -0
- swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc +0 -0
- swapae/util/__pycache__/__init__.cpython-38.pyc +0 -0
- swapae/util/__pycache__/html.cpython-38.pyc +0 -0
- swapae/util/__pycache__/iter_counter.cpython-38.pyc +0 -0
- swapae/util/__pycache__/metric_tracker.cpython-38.pyc +0 -0
- swapae/util/__pycache__/util.cpython-38.pyc +0 -0
- swapae/util/__pycache__/visualizer.cpython-38.pyc +0 -0
- tmp/0.png +0 -0
- tmp/1.png +0 -0
- tmp/2.png +0 -0
- tmp/3.png +0 -0
- tmp/4.png +0 -0
- tmp/5.png +0 -0
- tmp/6.png +0 -0
- tmp/7.png +0 -0
- tmp/8.png +0 -0
- tmp/9.png +0 -0
- tmp/tmp.png +0 -0
- weights/108004/exp_args.json +0 -63
app.py CHANGED

@@ -167,15 +167,15 @@ class Tester(TesterBase):
 
         sampled_code = tex_seg[:, tex_idx, :]
         rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, tex_size, tex_size)
-        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
+        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')[:1]
         H = tex_size // 8; W = tex_size // 8
         noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
         dec_input = torch.cat((sine_wave, noise), dim = 1)
 
-        weight = self.model.ChannelWeight(rec_tex)
-        weight = F.adaptive_avg_pool2d(weight, output_size = (1)).view(weight.shape[0], -1, 1, 1)
-        weight = torch.sigmoid(weight)
-        dec_input *= weight
+        #weight = self.model.ChannelWeight(rec_tex)
+        #weight = F.adaptive_avg_pool2d(weight, output_size = (1)).view(weight.shape[0], -1, 1, 1)
+        #weight = torch.sigmoid(weight)
+        #dec_input *= weight
 
         rep_rec = self.model.G(dec_input, rec_tex)
         rep_rec = (rep_rec + 1) / 2.0

@@ -256,6 +256,7 @@ class Tester(TesterBase):
         given_mask = torch.from_numpy(given_mask)
         H, W = given_mask.shape[0], given_mask.shape[1]
         given_mask = label2one_hot_torch(given_mask.view(1, 1, H, W), C = (given_mask.max()+1))
+        given_mask = F.interpolate(given_mask, size = (512, 512), mode = 'bilinear', align_corners = False)
         mask_img_list = []
         for i in range(given_mask.shape[1]):
             part_img = self.to_pil(given_mask[0, i])

@@ -291,22 +292,26 @@ class Tester(TesterBase):
         #tex_size = (tex_size // 8) * 8
         with torch.no_grad():
             edited = self.model_forward_editing(self.data, self.slic, options=options, given_mask=given_mask, return_type = 'edited')
-        col1, col2, col3, col4 = st.columns([1, 1, 1, 1])
+        col1, col2, col3 = st.columns([1, 1, 1])
         with col1:
-            st.markdown("")
-
-        with col2:
             st.markdown("Input image")
             img = F.interpolate(self.data, size = edited.shape[-2:], mode = 'bilinear', align_corners = False)
             st.image(self.to_pil((img + 1) / 2.0))
             print(img.shape, edited.shape)
 
+        with col2:
+            st.markdown("Given mask")
+            seg = Image.open(mask_path).convert("L")
+            seg = np.asarray(seg)
+            seg = torch.from_numpy(seg).view(1, 1, seg.shape[0], seg.shape[1])
+            color_vq = self.draw_color_seg(seg)
+            color_vq = F.interpolate(color_vq, size = (512, 512), mode = 'bilinear', align_corners = False)
+            st.image(self.to_pil(color_vq))
+
         with col3:
             st.markdown("Synthesized texture image")
             st.image(self.to_pil(edited))
 
-        with col4:
-            st.markdown("")
         st.markdown('<p class="big-font">You can choose another image from the examplar images on the top and start again!</p>', unsafe_allow_html=True)
 
     def model_forward_editing(self, rgb_img, slic, epoch = 1000, test_time = False,

@@ -334,23 +339,17 @@ class Tester(TesterBase):
         tex_seg = poolfeat(conv_feats, softmax, avg = True)
         seg = label2one_hot_torch(torch.argmax(softmax, dim = 1).unsqueeze(1), C = softmax.shape[1])
 
-        given_mask = F.interpolate(given_mask, size = (512, 512), mode = 'bilinear', align_corners = False)
         rec_tex = torch.zeros((1, tex_seg.shape[-1], 512, 512))
         for i in range(given_mask.shape[1]):
            label = options[i]
            code = tex_seg[0, label, :].view(1, -1, 1, 1).repeat(1, 1, 512, 512)
            rec_tex += code * given_mask[:, i:i+1]
         tex_size = 512
-        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')
+        sine_wave = self.model.get_sine_wave(rec_tex, 'rec')[:1]
         H = tex_size // 8; W = tex_size // 8
         noise = torch.randn(B, self.model.sine_wave_dim, H, W).to(tex_code.device)
         dec_input = torch.cat((sine_wave, noise), dim = 1)
 
-        weight = self.model.ChannelWeight(rec_tex)
-        weight = F.adaptive_avg_pool2d(weight, output_size = (1)).view(weight.shape[0], -1, 1, 1)
-        weight = torch.sigmoid(weight)
-        dec_input *= weight
-
         rep_rec = self.model.G(dec_input, rec_tex)
         rep_rec = (rep_rec + 1) / 2.0
         return rep_rec

@@ -359,6 +358,7 @@ class Tester(TesterBase):
         rgb_img = Image.open(data_path)
         crop_size = self.args.crop_size
         i = 40; j = 40; h = crop_size; w = crop_size
+        rgb_img = transforms.Resize(size=320)(rgb_img)
         rgb_img = TF.crop(rgb_img, i, j, h, w)
 
         # compute superpixel
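Note on the `[:1]` slice added in the first and fourth hunks: the updated get_sine_wave in models/week0417/model.py (below) stacks a zero-offset wave and a random-offset wave along the batch dimension, so the editing paths keep only the zero-offset half. A minimal shape sketch, standalone and with assumed dimensions rather than code from the repo:

import torch

B, C, H, W = 1, 32, 64, 64                         # assumes sine_wave_dim = 32
raw = torch.randn(B, 2 * C, H, W)                  # stand-in for period * grid
zwave = torch.sin(raw[:, ::2] + raw[:, 1::2])      # zero-offset wave, batch B
offset = (2 * torch.rand(B, C, 1, 1) - 1) * 6.28   # random phase per channel
rwave = torch.sin(raw[:, ::2] + raw[:, 1::2] + offset)
wave = torch.cat((zwave, rwave))                   # batch doubles to 2*B
sine_wave = wave[:1]                               # the app keeps the zero-offset half
assert sine_wave.shape == (1, C, H, W)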
data/images/108073.jpg DELETED (binary file, 76.4 kB)
data/{test_images → images}/130014.jpg RENAMED (file without changes)
data/images/134008.jpg DELETED (binary file, 61.8 kB)
data/images/25098.jpg DELETED (binary file, 80.1 kB)
data/images/45077.jpg DELETED (binary file, 84.3 kB)
data/images/corn.jpg ADDED
data/test_images/108004.jpg DELETED (binary file, 107 kB)
data/test_images/130066.jpg DELETED (binary file, 62.9 kB)
data/test_images/16068.jpg DELETED (binary file, 79.2 kB)
data/test_images/208078.jpg DELETED (binary file, 81.3 kB)
data/test_images/223060.jpg DELETED (binary file, 82.5 kB)
data/test_images/388006.jpg DELETED (binary file, 76.2 kB)
data/test_images/78098.jpg DELETED (binary file, 103 kB)
models/week0417/__pycache__/loss.cpython-38.pyc ADDED (binary file, 14 kB)
models/week0417/__pycache__/meanshift_utils.cpython-37.pyc ADDED (binary file, 1.78 kB)
models/week0417/__pycache__/meanshift_utils.cpython-38.pyc ADDED (binary file, 1.79 kB)
models/week0417/__pycache__/model.cpython-37.pyc CHANGED (binary files differ)
models/week0417/__pycache__/model.cpython-38.pyc CHANGED (binary files differ)
models/week0417/__pycache__/nnutils.cpython-38.pyc CHANGED (binary files differ)
models/week0417/__pycache__/taming_blocks.cpython-38.pyc CHANGED (binary files differ)
models/week0417/meanshift_utils.py ADDED (@@ -0,0 +1,62 @@)

import torch
import numpy as np

def pairwise_distances(x, y):
    # Input: x is a Nxd matrix
    #        y is an optional Mxd matrix
    # Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:],
    #         i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
    # If y is not given then use 'y=x'.
    x_norm = (x ** 2).sum(1).view(-1, 1)
    y_t = torch.transpose(y, 0, 1)
    y_norm = (y ** 2).sum(1).view(1, -1)
    dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
    return torch.clamp(dist, 0.0, np.inf)

def meanshift_cluster(pts, bandwidth, weights = None, meanshift_step = 15, step_size = 0.3):
    """
    meanshift written in pytorch
    :param pts: input points
    :param weights: weight per point during clustering
    :return: clustered points
    """
    pts_steps = []
    for i in range(meanshift_step):
        Y = pairwise_distances(pts, pts)
        K = torch.nn.functional.relu(bandwidth ** 2 - Y)
        if weights is not None:
            K = K * weights
        P = torch.nn.functional.normalize(K, p=1, dim=0, eps=1e-10)
        P = P.transpose(0, 1)
        pts = step_size * (torch.matmul(P, pts) - pts) + pts
        pts_steps.append(pts)
    return pts_steps

def distance(a, b):
    return torch.sqrt(((a-b)**2).sum())

def meanshift_assign(points, bandwidth):
    cluster_ids = []
    cluster_idx = 0
    cluster_centers = []

    for i, point in enumerate(points):
        if len(cluster_ids) == 0:
            cluster_ids.append(cluster_idx)
            cluster_centers.append(point)
            cluster_idx += 1
        else:
            # assign to nearest cluster
            #for j, center in enumerate(cluster_centers):
            #    dist = distance(point, center)
            #    if(dist < bandwidth):
            #        cluster_ids.append(j)
            cdist = torch.cdist(point.unsqueeze(0), torch.stack(cluster_centers), p = 2)
            nearest_idx = torch.argmin(cdist, dim = 1)
            if cdist[0, nearest_idx] < bandwidth:
                cluster_ids.append(nearest_idx)
            else:
                cluster_ids.append(cluster_idx)
                cluster_centers.append(point)
                cluster_idx += 1
    return cluster_ids, cluster_centers
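A minimal usage sketch for the two helpers above (not part of the commit; assumes it runs from inside models/week0417 so the module imports directly, and uses the bandwidth of 3.0 that model.py sets as self.bandwidth):

import torch
from meanshift_utils import meanshift_cluster, meanshift_assign

torch.manual_seed(0)
# two well-separated 2-D blobs
pts = torch.cat((torch.randn(20, 2), torch.randn(20, 2) + 10.0))
# run 15 mean-shift iterations and keep the final iterate
shifted = meanshift_cluster(pts, bandwidth = 3.0, meanshift_step = 15)[-1]
with torch.no_grad():
    ids, centers = meanshift_assign(shifted, bandwidth = 3.0)
print(len(centers))  # expect 2 clusters for blobs this far apart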
models/week0417/model.py CHANGED

@@ -7,21 +7,26 @@ import torchvision.transforms as transforms
 import torchvision.transforms.functional as TF
 
 from .taming_blocks import Encoder
+from .loss import styleLossMaskv3
 from .nnutils import SPADEResnetBlock, get_edges, initWave
 
 from libs.nnutils import poolfeat, upfeat
 from libs.utils import label2one_hot_torch
+from .meanshift_utils import meanshift_cluster, meanshift_assign
 
 from swapae.models.networks.stylegan2_layers import ConvLayer
 from torch_geometric.nn import GCNConv
 from torch_geometric.utils import softmax
-from .loss import styleLossMaskv3
+
+import sys
+sys.path.append('models/third_party/cython')
+from connectivity import enforce_connectivity
 
 class GCN(nn.Module):
-    def __init__(self, n_cluster, temperature = 1, add_self_loops = True, hidden_dim = 256):
+    def __init__(self, n_cluster, temperature = 1, hidden_dim = 256):
         super().__init__()
-        self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
-        self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
+        self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = True)
+        self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = True)
         self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
         self.temperature = temperature
 

@@ -65,10 +70,7 @@ class GCN(nn.Module):
 
                 # compute texture code w.r.t grouping
                 pool_feat = poolfeat(conv_feat, s_, avg = True)
-                # hard upsampling
-                #hard_s_ = label2one_hot_torch(torch.argmax(s_, dim = 1).unsqueeze(1), C = s_.shape[1])
                 feat = upfeat(pool_feat, s_)
-                #feat = upfeat(pool_feat, hard_s_)
 
             prop_code.append(feat)
             sp_assign.append(pred_mask)

@@ -137,10 +139,7 @@ class AE(nn.Module):
         # encoder & decoder
         self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[],
                            in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
-        if args.dec_input_mode == 'sine_wave_noise':
-            self.G = SPADEGenerator(args.spatial_code_dim * 2, args.hidden_dim)
-        else:
-            self.G = SPADEGenerator(args.spatial_code_dim, args.hidden_dim)
+        self.G = SPADEGenerator(args.spatial_code_dim + 32, args.hidden_dim)
 
         self.add_module(
             "ToTexCode",

@@ -150,40 +149,61 @@ class AE(nn.Module):
                 ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False)
             )
         )
-        self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, add_self_loops = (args.add_self_loops == 1), hidden_dim = args.hidden_dim)
+        self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, hidden_dim = args.hidden_dim)
 
         self.add_gcn_epoch = args.add_gcn_epoch
         self.add_clustering_epoch = args.add_clustering_epoch
         self.add_texture_epoch = args.add_texture_epoch
 
         self.patch_size = args.patch_size
+        self.style_loss = styleLossMaskv3(device = args.device)
         self.sine_wave_dim = args.spatial_code_dim
+        self.noise_dim = 32
+        self.spatial_code_dim = args.spatial_code_dim
 
         # inpainting network
-        self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
-        self.dec_input_mode = args.dec_input_mode
-        self.style_loss = styleLossMaskv3(device = args.device)
-
-        if args.sine_weight:
-            if args.dec_input_mode == 'sine_wave_noise':
-                self.add_module(
-                    "ChannelWeight",
-                    nn.Sequential(
-                        ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
-                        ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
-                        ConvLayer(args.hidden_dim//4, args.spatial_code_dim*2, kernel_size=1, activate=False, bias=False, downsample=True)))
-            else:
-                self.add_module(
-                    "ChannelWeight",
-                    nn.Sequential(
-                        ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
-                        ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
-                        ConvLayer(args.hidden_dim//4, args.spatial_code_dim, kernel_size=1, activate=False, bias=False, downsample=True)))
-
-    def get_sine_wave(self, GL, offset_mode = 'random'):
-        img_size = GL.shape[-1] // 8
-        GL = F.interpolate(GL, size = (img_size, img_size), mode = 'nearest')
-        xv, yv = np.meshgrid(np.arange(img_size), np.arange(img_size), indexing='ij')
+        if args.spatial_code_dim > 0:
+            self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
+
+        self.add_module(
+            "Amplitude",
+            nn.Sequential(
+                nn.Conv2d(args.hidden_dim, args.hidden_dim//2, 1, 1, 0),
+                nn.Conv2d(args.hidden_dim//2, args.hidden_dim//4, 1, 1, 0),
+                nn.Conv2d(args.hidden_dim//4, args.spatial_code_dim, 1, 1, 0)
+            )
+        )
+
+        self.bandwidth = 3.0
+
+    def sample_patch_from_mask(self, mask, patch_num = 10, patch_size = 64):
+        """
+        Sample `patch_num` patches of size `patch_size*patch_size` w.r.t the given mask
+        """
+        nonzeros = torch.nonzero(mask.view(-1)).squeeze()
+        n = len(nonzeros)
+        xys = []
+        imgH, imgW = mask.shape
+        half_patch = patch_size // 2
+        iter_num = 0
+        while len(xys) < patch_num:
+            id = (torch.ones(n)*1.0/n).multinomial(num_samples=1, replacement=False)
+            rx = nonzeros[id] // imgW
+            ry = nonzeros[id] % imgW
+            top = max(0, rx - half_patch)
+            bot = min(imgH, rx + half_patch)
+            left = max(0, ry - half_patch)
+            right = min(imgW, ry + half_patch)
+            patch_mask = mask[top:bot, left:right]
+            if torch.sum(patch_mask) / (patch_size ** 2) > 0.5 or iter_num > 20:
+                xys.append([top, bot, left, right])
+            iter_num += 1
+        return xys
+
+    def get_sine_wave(self, GL, offset_mode = 'rec'):
+        imgH, imgW = GL.shape[-2]//8, GL.shape[-1]//8
+        GL = F.interpolate(GL, size = (imgH, imgW), mode = 'nearest')
+        xv, yv = np.meshgrid(np.arange(imgH), np.arange(imgW), indexing='ij')
         c = torch.FloatTensor(np.concatenate([xv[np.newaxis], yv[np.newaxis]], 0)[np.newaxis])
         c = c.to(GL.device)
         # c: 1, 2, 28, 28

@@ -192,13 +212,170 @@ class AE(nn.Module):
         period = self.learnedWN(GL)
         # period: 1, 64, 28, 28
         raw = period * c
-        if offset_mode == 'random':
-            offset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
-            offset = offset.repeat(1, 1, img_size, img_size)
-            wave = torch.sin(raw[:, ::2] + raw[:, 1::2] + offset)
-        elif offset_mode == 'rec':
-            wave = torch.sin(raw[:, ::2] + raw[:, 1::2])
+
+        # random offset
+        roffset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
+        roffset = roffset.repeat(1, 1, imgH, imgW)
+        rwave = torch.sin(raw[:, ::2] + raw[:, 1::2] + roffset)
+
+        # zero offset
+        zwave = torch.sin(raw[:, ::2] + raw[:, 1::2])
+        A = self.Amplitude(GL)
+        A = torch.sigmoid(A)
+        wave = torch.cat((zwave, rwave)) * A.repeat(2, 1, 1, 1)
         return wave
 
+    def syn_tex(self, tex_code, mask, imgH, imgW, offset_mode = 'rec', tex_idx = None):
+        # synthesize all textures
+        # spatial: B x 256 x 14 x 14
+        # tex_code: B x N x 256
+        B, N, _ = tex_code.shape
+        H = imgH // 8
+        W = imgW // 8
+
+        # randomly sample a texture and synthesize it;
+        # throw away small texture segments
+        areas = torch.sum(mask, dim=(2, 3))
+        valid_idxs = torch.nonzero(areas[0] / (imgH * imgW) > 0.01).squeeze(-1)
+        if tex_idx is None or tex_idx >= tex_code.shape[1]:
+            tex_idx = valid_idxs[torch.multinomial(areas[0, valid_idxs], 1).squeeze()]
+        else:
+            sorted_list = torch.argsort(areas, dim = 1, descending = True)
+            tex_idx = sorted_list[0, tex_idx]
+        sampled_code = tex_code[:, tex_idx, :]
+        rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
+
+        # Decoder: Spatial & Texture code -> Image
+        if self.noise_dim == 0:
+            dec_input = self.get_sine_wave(rec_tex, offset_mode)
+        elif self.spatial_code_dim == 0:
+            dec_input = torch.randn(rec_tex.shape[0], self.noise_dim, H, W).to(tex_code.device)
+        else:
+            sine_wave = self.get_sine_wave(rec_tex, offset_mode)
+            noise = torch.randn(sine_wave.shape[0], self.noise_dim, H, W).to(tex_code.device)
+            dec_input = torch.cat((sine_wave, noise), dim = 1)
+
+        tex_syn = self.G(dec_input, rec_tex.repeat(dec_input.shape[0], 1, 1, 1))
+
+        return tex_syn, tex_idx
+
+    def sample_tex_patches(self, tex_idx, rgb_img, rep_rec, mask, patch_num = 10):
+        patches = []
+        masks = []
+        patch_masks = []
+        # sample patches from input image and reconstruction
+        for i in range(rgb_img.shape[0]):
+            # WARNING: this only works for batch_size = 1 for now
+            maski = mask[i, tex_idx]
+            masks.append(maski.unsqueeze(0))
+            xys = self.sample_patch_from_mask(maski, patch_num = patch_num, patch_size = self.patch_size)
+            # sample 10 patches from input image & reconstruction w.r.t group mask
+            for k in range(patch_num):
+                top, bot, left, right = xys[k]
+                patch_ = rgb_img[i, :, top:bot, left:right]
+                patch_mask_ = maski[top:bot, left:right]
+
+                # In case the patch is on the boundary and smaller than patch_size,
+                # we put the patch at some random place of a black image
+                h, w = patch_.shape[-2:]
+                x = 0; y = 0
+                if h < self.patch_size:
+                    x = np.random.randint(0, self.patch_size - h)
+                if w < self.patch_size:
+                    y = np.random.randint(0, self.patch_size - w)
+                patch = torch.zeros(1, 3, self.patch_size, self.patch_size).to(patch_.device)
+                patch_mask = torch.zeros(1, 1, self.patch_size, self.patch_size).to(patch_.device)
+                patch[:, :, x:x+h, y:y+w] = patch_
+                patch_mask[:, :, x:x+h, y:y+w] = patch_mask_
+                patches.append(patch)
+                patch_masks.append(patch_mask)
+        patches = torch.cat(patches)
+        masks = torch.stack(masks)
+        patch_masks = torch.cat(patch_masks)
+
+        # sample patches from synthesized texture
+        tex_patch_size = self.patch_size
+        rep_patches = []
+        for k in range(patch_num):
+            i, j, h, w = transforms.RandomCrop.get_params(rep_rec, output_size=(tex_patch_size, tex_patch_size))
+            rep_rec_patch = TF.crop(rep_rec, i, j, h, w)
+            rep_patches.append(rep_rec_patch)
+        rep_patches = torch.stack(rep_patches, dim = 1)
+        rep_patches = rep_patches.view(-1, 3, tex_patch_size, tex_patch_size)
+        return masks, patch_masks, patches, rep_patches
+
     def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None):
-        return
+        #self.patch_size = np.random.randint(64, 160)
+        B, _, imgH, imgW = rgb_img.shape
+        outputs = {}
+        rec_feat_list = []
+        seg_map = [torch.argmax(slic.cpu(), dim = 1)]
+
+        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
+        conv_feat, layer_feats = self.enc(rgb_img)
+        B, C, H, W = conv_feat.shape
+
+        # Texture code for each superpixel
+        tex_code = self.ToTexCode(conv_feat)
+
+        code = F.interpolate(tex_code, size = (imgH, imgW), mode = 'bilinear', align_corners = False)
+        pool_code = poolfeat(code, slic, avg = True)
+
+        if epoch >= self.add_gcn_epoch:
+            prop_code, sp_assign, conv_feats = self.gcn(pool_code, slic, (self.add_clustering_epoch <= epoch))
+            softmax = F.softmax(sp_assign * self.gcn.temperature, dim = 1)
+            rec_feat_list.append(prop_code)
+            seg_map = [torch.argmax(sp_assign.cpu(), dim = 1)]
+        else:
+            rec_code = upfeat(pool_code, slic)
+            rec_feat_list.append(rec_code)
+            softmax = slic
+
+        # Texture synthesis
+        if epoch >= self.add_texture_epoch:
+            sp_feat = poolfeat(conv_feats, slic, avg = True).squeeze(0)
+            pts = meanshift_cluster(sp_feat, self.bandwidth, meanshift_step = 15)[-1]
+            with torch.no_grad():
+                sp_assign, _ = meanshift_assign(pts, self.bandwidth)
+                sp_assign = torch.tensor(sp_assign).unsqueeze(-1).to(slic.device).float()
+                sp_assign = upfeat(sp_assign, slic)
+            seg = label2one_hot_torch(sp_assign, C = sp_assign.max().long() + 1)
+            seg_map = [torch.argmax(seg.cpu(), dim = 1)]
+
+            # texture code for each connected group
+            tex_seg = poolfeat(conv_feats, seg, avg = True)
+            if test:
+                rep_rec, tex_idx = self.syn_tex(tex_seg, seg, 564, 564, tex_idx = tex_idx)
+                #rep_rec, tex_idx = self.syn_tex(tex_seg, seg, 1024, 1024, tex_idx = tex_idx)
+            else:
+                rep_rec, tex_idx = self.syn_tex(tex_seg, seg, imgH, imgW, tex_idx = tex_idx)
+            rep_rec = (rep_rec + 1) / 2.0
+            rgb_img = (rgb_img + 1) / 2.0
+
+            # sample patches from input image, reconstruction & synthesized texture
+            # zero offset
+            zmasks, zpatch_masks, zpatches, zrep_patches = self.sample_tex_patches(tex_idx, rgb_img, rep_rec[:1], seg)
+            # random offset
+            rmasks, rpatch_masks, rpatches, rrep_patches = self.sample_tex_patches(tex_idx, rgb_img, rep_rec[1:], seg)
+            masks = torch.cat((zmasks, rmasks))
+            patch_masks = torch.cat((zpatch_masks, rpatch_masks))
+            patches = torch.cat((zpatches, rpatches))
+            rep_patches = torch.cat((zrep_patches, rrep_patches))
+
+            # Gram matrix matching loss between:
+            # - patches from synthesized texture vs. patches from input image
+            # - patches from reconstruction vs. patches from input image
+            outputs['style_loss'] = self.style_loss.forward_patch_img(rep_patches, rgb_img.repeat(2, 1, 1, 1), masks)
+
+            outputs['rep_rec'] = rep_rec
+            outputs['masks'] = masks
+            outputs['patches'] = patches.view(-1, 3, self.patch_size, self.patch_size)
+            outputs['patch_masks'] = patch_masks
+            outputs['rep_patches'] = rep_patches * patch_masks + patches * (1 - patch_masks)
+
+            outputs['gt'] = rgb_img
+            bp_tex = rep_rec[:1, :, :imgH, :imgW] * masks[:1] + rgb_img * (1 - masks[:1])
+            outputs['rec'] = bp_tex
+
+        outputs['HA'] = torch.cat(seg_map)
+        return outputs
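The decoder input assembled in the new syn_tex concatenates the sine waves with 32 noise channels, which is why the generator is now built as SPADEGenerator(args.spatial_code_dim + 32, args.hidden_dim). A minimal shape sketch with assumed dimensions (spatial_code_dim = 32, a 28x28 grid), not code from the repo:

import torch

spatial_code_dim, noise_dim = 32, 32  # noise_dim is hard-coded to 32 in this commit
H, W = 28, 28                         # imgH // 8, imgW // 8
# get_sine_wave returns the zero-offset and random-offset halves stacked, so batch = 2
sine_wave = torch.randn(2, spatial_code_dim, H, W)
noise = torch.randn(sine_wave.shape[0], noise_dim, H, W)
dec_input = torch.cat((sine_wave, noise), dim = 1)
# channel count matches SPADEGenerator(args.spatial_code_dim + 32, args.hidden_dim)
assert dec_input.shape == (2, spatial_code_dim + noise_dim, H, W)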
models/week0417/model_bk.py ADDED (@@ -0,0 +1,204 @@)

import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF

from .taming_blocks import Encoder
from .nnutils import SPADEResnetBlock, get_edges, initWave

from libs.nnutils import poolfeat, upfeat
from libs.utils import label2one_hot_torch

from swapae.models.networks.stylegan2_layers import ConvLayer
from torch_geometric.nn import GCNConv
from torch_geometric.utils import softmax
from .loss import styleLossMaskv3

class GCN(nn.Module):
    def __init__(self, n_cluster, temperature = 1, add_self_loops = True, hidden_dim = 256):
        super().__init__()
        self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
        self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = add_self_loops)
        self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
        self.temperature = temperature

    def compute_edge_score_softmax(self, raw_edge_score, edge_index, num_nodes):
        return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)

    def compute_edge_weight(self, node_feature, edge_index):
        src_feat = torch.gather(node_feature, 0, edge_index[0].unsqueeze(1).repeat(1, node_feature.shape[1]))
        tgt_feat = torch.gather(node_feature, 0, edge_index[1].unsqueeze(1).repeat(1, node_feature.shape[1]))
        raw_edge_weight = nn.CosineSimilarity(dim=1, eps=1e-6)(src_feat, tgt_feat)
        edge_weight = self.compute_edge_score_softmax(raw_edge_weight, edge_index, node_feature.shape[0])
        return raw_edge_weight.squeeze(), edge_weight.squeeze()

    def forward(self, sp_code, slic, clustering = False):
        edges, aff = get_edges(torch.argmax(slic, dim = 1).unsqueeze(1), sp_code.shape[1])
        prop_code = []
        sp_assign = []
        edge_weights = []
        conv_feats = []
        for i in range(sp_code.shape[0]):
            # compute edge weight
            edge_index = edges[i]
            raw_edge_weight, edge_weight = self.compute_edge_weight(sp_code[i], edge_index)
            feat = self.gcnconv1(sp_code[i], edge_index, edge_weight = edge_weight)
            raw_edge_weight, edge_weight = self.compute_edge_weight(feat, edge_index)
            edge_weights.append(raw_edge_weight)
            feat = F.leaky_relu(feat, 0.2)
            feat = self.gcnconv2(feat, edge_index, edge_weight = edge_weight)

            # maybe clustering
            conv_feat = upfeat(feat, slic[i:i+1])
            conv_feats.append(conv_feat)
            if not clustering:
                feat = conv_feat
                pred_mask = slic[i:i+1]
            else:
                pred_mask = self.pool1(conv_feat)
                # enforce pixels belonging to the same superpixel to have the same grouping label
                pred_mask = upfeat(poolfeat(pred_mask, slic[i:i+1]), slic[i:i+1])
                s_ = F.softmax(pred_mask * self.temperature, dim = 1)

                # compute texture code w.r.t grouping
                pool_feat = poolfeat(conv_feat, s_, avg = True)
                # hard upsampling
                #hard_s_ = label2one_hot_torch(torch.argmax(s_, dim = 1).unsqueeze(1), C = s_.shape[1])
                feat = upfeat(pool_feat, s_)
                #feat = upfeat(pool_feat, hard_s_)

            prop_code.append(feat)
            sp_assign.append(pred_mask)
        prop_code = torch.cat(prop_code)
        conv_feats = torch.cat(conv_feats)
        return prop_code, torch.cat(sp_assign), conv_feats

class SPADEGenerator(nn.Module):
    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        nf = hidden_dim // 16

        self.head_0 = SPADEResnetBlock(in_dim, 16 * nf)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf)
        self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf)

        final_nc = nf

        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def forward(self, sine_wave, texon):

        x = self.head_0(sine_wave, texon)

        x = self.up(x)
        x = self.G_middle_0(x, texon)
        x = self.G_middle_1(x, texon)

        x = self.up(x)
        x = self.up_0(x, texon)
        x = self.up(x)
        x = self.up_1(x, texon)
        #x = self.up(x)
        x = self.up_2(x, texon)
        #x = self.up(x)
        x = self.up_3(x, texon)

        x = self.conv_img(F.leaky_relu(x, 2e-1))
        return x

class Waver(nn.Module):
    def __init__(self, tex_code_dim, zPeriodic):
        super(Waver, self).__init__()
        K = tex_code_dim
        layers = [nn.Conv2d(tex_code_dim, K, 1)]
        layers += [nn.ReLU(True)]
        layers += [nn.Conv2d(K, 2 * zPeriodic, 1)]
        self.learnedWN = nn.Sequential(*layers)
        self.waveNumbers = initWave(zPeriodic)

    def forward(self, GLZ=None):
        return (self.waveNumbers.to(GLZ.device) + self.learnedWN(GLZ))

class AE(nn.Module):
    def __init__(self, args, **ignore_kwargs):
        super(AE, self).__init__()

        # encoder & decoder
        self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1,2,4,8], num_res_blocks=1, attn_resolutions=[],
                           in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
        if args.dec_input_mode == 'sine_wave_noise':
            self.G = SPADEGenerator(args.spatial_code_dim * 2, args.hidden_dim)
        else:
            self.G = SPADEGenerator(args.spatial_code_dim, args.hidden_dim)

        self.add_module(
            "ToTexCode",
            nn.Sequential(
                ConvLayer(args.hidden_dim, args.hidden_dim, kernel_size=3, activate=True, bias=True),
                ConvLayer(args.hidden_dim, args.tex_code_dim, kernel_size=3, activate=True, bias=True),
                ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False)
            )
        )
        self.gcn = GCN(n_cluster = args.n_cluster, temperature = args.temperature, add_self_loops = (args.add_self_loops == 1), hidden_dim = args.hidden_dim)

        self.add_gcn_epoch = args.add_gcn_epoch
        self.add_clustering_epoch = args.add_clustering_epoch
        self.add_texture_epoch = args.add_texture_epoch

        self.patch_size = args.patch_size
        self.sine_wave_dim = args.spatial_code_dim

        # inpainting network
        self.learnedWN = Waver(args.hidden_dim, zPeriodic = args.spatial_code_dim)
        self.dec_input_mode = args.dec_input_mode
        self.style_loss = styleLossMaskv3(device = args.device)

        if args.sine_weight:
            if args.dec_input_mode == 'sine_wave_noise':
                self.add_module(
                    "ChannelWeight",
                    nn.Sequential(
                        ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
                        ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
                        ConvLayer(args.hidden_dim//4, args.spatial_code_dim*2, kernel_size=1, activate=False, bias=False, downsample=True)))
            else:
                self.add_module(
                    "ChannelWeight",
                    nn.Sequential(
                        ConvLayer(args.hidden_dim, args.hidden_dim//2, kernel_size=3, activate=True, bias=True, downsample=True),
                        ConvLayer(args.hidden_dim//2, args.hidden_dim//4, kernel_size=3, activate=True, bias=True, downsample=True),
                        ConvLayer(args.hidden_dim//4, args.spatial_code_dim, kernel_size=1, activate=False, bias=False, downsample=True)))

    def get_sine_wave(self, GL, offset_mode = 'random'):
        img_size = GL.shape[-1] // 8
        GL = F.interpolate(GL, size = (img_size, img_size), mode = 'nearest')
        xv, yv = np.meshgrid(np.arange(img_size), np.arange(img_size), indexing='ij')
        c = torch.FloatTensor(np.concatenate([xv[np.newaxis], yv[np.newaxis]], 0)[np.newaxis])
        c = c.to(GL.device)
        # c: 1, 2, 28, 28
        c = c.repeat(GL.shape[0], self.sine_wave_dim, 1, 1)
        # c: 1, 64, 28, 28
        period = self.learnedWN(GL)
        # period: 1, 64, 28, 28
        raw = period * c
        if offset_mode == 'random':
            offset = torch.zeros((GL.shape[0], self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
            offset = offset.repeat(1, 1, img_size, img_size)
            wave = torch.sin(raw[:, ::2] + raw[:, 1::2] + offset)
        elif offset_mode == 'rec':
            wave = torch.sin(raw[:, ::2] + raw[:, 1::2])
        return wave

    def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None):
        return
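For contrast with the new Amplitude head, the ChannelWeight gating kept in this backup (and commented out in app.py above) amounted to a per-channel squeeze-and-gate on the decoder input. A standalone sketch with stand-in tensors and assumed shapes, reconstructed from the removed app.py lines:

import torch
import torch.nn.functional as F

dec_input = torch.randn(1, 64, 64, 64)   # sine-wave + noise channels
# stand-in for self.model.ChannelWeight(rec_tex), a small downsampling conv stack
weight = torch.randn(1, 64, 8, 8)
weight = F.adaptive_avg_pool2d(weight, output_size = 1).view(weight.shape[0], -1, 1, 1)
weight = torch.sigmoid(weight)           # per-channel gate in (0, 1)
dec_input = dec_input * weight           # broadcasts over the spatial grid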
swapae/models/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
swapae/models/__pycache__/base_model.cpython-38.pyc CHANGED (binary files differ)
swapae/models/networks/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
swapae/models/networks/__pycache__/base_network.cpython-38.pyc CHANGED (binary files differ)
swapae/models/networks/__pycache__/stylegan2_layers.cpython-38.pyc CHANGED (binary files differ)
swapae/models/networks/stylegan2_op/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
swapae/models/networks/stylegan2_op/__pycache__/fused_act.cpython-38.pyc CHANGED (binary files differ)
swapae/models/networks/stylegan2_op/__pycache__/upfirdn2d.cpython-38.pyc CHANGED (binary files differ)
swapae/util/__pycache__/__init__.cpython-38.pyc CHANGED (binary files differ)
swapae/util/__pycache__/html.cpython-38.pyc CHANGED (binary files differ)
swapae/util/__pycache__/iter_counter.cpython-38.pyc CHANGED (binary files differ)
swapae/util/__pycache__/metric_tracker.cpython-38.pyc CHANGED (binary files differ)
swapae/util/__pycache__/util.cpython-38.pyc CHANGED (binary files differ)
swapae/util/__pycache__/visualizer.cpython-38.pyc CHANGED (binary files differ)
tmp/0.png ADDED
tmp/1.png ADDED
tmp/2.png ADDED
tmp/3.png ADDED
tmp/4.png ADDED
tmp/5.png ADDED
tmp/6.png ADDED
tmp/7.png ADDED
tmp/8.png ADDED
tmp/9.png ADDED
tmp/tmp.png ADDED
weights/108004/exp_args.json DELETED (@@ -1,63 +0,0 @@)

The deleted file contained:

{
    "data_path": "/home/xtli/DATA/BSR_processed/train",
    "img_path": "data/test_images/108004.jpg",
    "test_path": null,
    "crop_size": 224,
    "scale_size": null,
    "batch_size": 1,
    "workers": 4,
    "pretrained_path": "/home/xtli/WORKDIR/04-15/single_scale_grouping_resume/cpk.pth",
    "hidden_dim": 256,
    "spatial_code_dim": 32,
    "tex_code_dim": 256,
    "exp_name": "04-18/108004",
    "project_name": "ssn_transformer",
    "nepochs": 20,
    "lr": 5e-05,
    "momentum": 0.5,
    "beta": 0.999,
    "lr_decay_freq": 3000,
    "save_freq": 1000,
    "save_freq_iter": 2000000000000,
    "log_freq": 10,
    "display_freq": 100,
    "use_wandb": 0,
    "work_dir": "/home/xtli/WORKDIR",
    "out_dir": "/home/xtli/WORKDIR/04-18/108004",
    "local_rank": 0,
    "dataset": "dataset",
    "config_file": "models/week0417/json/single_scale_grouping_ft.json",
    "lambda_L1": 1,
    "lambda_Perceptual": 1.0,
    "lambda_PatchGAN": 1.0,
    "lambda_GAN": 1,
    "add_gan_epoch": 0,
    "lambda_kld_loss": 1e-06,
    "lambda_style_loss": 1.0,
    "lambda_feat": 10.0,
    "use_slic": true,
    "patch_size": 64,
    "netPatchD_scale_capacity": 4.0,
    "netPatchD_max_nc": 384,
    "netPatchD": "StyleGAN2",
    "use_antialias": true,
    "patch_use_aggregation": false,
    "lambda_R1": 1.0,
    "lambda_ffl_loss": 1.0,
    "lambda_patch_R1": 1.0,
    "R1_once_every": 16,
    "add_self_loops": 1,
    "test_time": false,
    "sp_num": 196,
    "label_path": "/home/xtli/DATA/BSR/BSDS500/data/groundTruth",
    "model_name": "model",
    "num_D": 2,
    "n_layers_D": 3,
    "n_cluster": 10,
    "temperature": 23,
    "add_gcn_epoch": 0,
    "add_clustering_epoch": 0,
    "add_texture_epoch": 0,
    "dec_input_mode": "sine_wave_noise",
    "sine_weight": true
}
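Since the code reads these settings as attributes (args.hidden_dim, args.spatial_code_dim, and so on), a dump like this one can be rehydrated into a namespace. A hypothetical sketch, assuming such a JSON file is still present at the given path:

import json
import argparse

with open("weights/108004/exp_args.json") as f:
    args = argparse.Namespace(**json.load(f))
print(args.spatial_code_dim, args.hidden_dim, args.dec_input_mode)  # 32 256 sine_wave_noise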