DmitrMakeev committed
Commit
98c5805
1 Parent(s): 2e9004e

Upload 9 files

models/__init__.py ADDED
File without changes
models/anchor_gen.py ADDED
@@ -0,0 +1,107 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.autograd import Function
5
+ from models import basic, clusterkit
6
+ import pdb
7
+
8
+ class AnchorAnalysis:
9
+ def __init__(self, mode, colorLabeler):
10
+ ## anchor generating mode: 1.random; 2.clustering
11
+ self.mode = mode
12
+ self.colorLabeler = colorLabeler
13
+
14
+ def _detect_correlation(self, data_tensors, color_probs, hint_masks, thres=0.1):
15
+ N,C,H,W = data_tensors.shape
16
+ ## (N,C,HW)
17
+ data_vecs = data_tensors.flatten(2)
18
+ prob_vecs = color_probs.flatten(2)
19
+ mask_vecs = hint_masks.flatten(2)
20
+ #anchor_data = torch.masked_select(data_vecs, mask_vecs.bool()).view(N,C,-1)
21
+ #anchor_prob = torch.masked_select(prob_vecs, mask_vecs.bool()).view(N,313,-1)
22
+ #_,_,K = anchor_data.shape
23
+ anchor_mask = torch.matmul(mask_vecs.permute(0,2,1), mask_vecs)
24
+ cosine_sim = True
25
+ ## non-similarity matrix
26
+ if cosine_sim:
27
+ norm_data = F.normalize(data_vecs, p=2, dim=1)
28
+ ## (N,HW,HW) = (N,HW,C) X (N,C,HW)
29
+ corr_matrix = torch.matmul(norm_data.permute(0,2,1), norm_data)
30
+ ## remapping: [-1.0,1.0] to [0.0,1.0], and convert into dis-similarity
31
+ dist_matrix = 1.0 - 0.5*(corr_matrix + 1.0)
32
+ else:
33
+ ## (N,HW,HW) = (N,HW,C) X (N,C,HW)
34
+ XtX = torch.matmul(data_vecs.permute(0,2,1), data_vecs)
35
+ diag_vec = torch.diagonal(XtX, dim1=-2, dim2=-1)
36
+ A = diag_vec.unsqueeze(1).repeat(1,H*W,1)
37
+ At = diag_vec.unsqueeze(2).repeat(1,1,H*W)
38
+ dist_matrix = A - 2*XtX + At
39
+ #dist_matrix = dist_matrix + 1e7*torch.eye(K).to(data_tensors.device).repeat(N,1,1)
40
+ ## for debug use
41
+ K = 8
42
+ anchor_adj_matrix = torch.masked_select(dist_matrix, anchor_mask.bool()).view(N,K,K)
43
+ ## detect connected nodes
44
+ adj_matrix = torch.where((dist_matrix < thres) & (anchor_mask > 0), torch.ones_like(dist_matrix), torch.zeros_like(dist_matrix))
45
+ adj_matrix = torch.matmul(adj_matrix, adj_matrix)
46
+ adj_matrix = adj_matrix / (1e-7+adj_matrix)
47
+ ## merge nodes
48
+ ## (N,K,C) = (N,K,K) X (N,K,C)
49
+ anchor_prob = torch.matmul(adj_matrix, prob_vecs.permute(0,2,1)) / torch.sum(adj_matrix, dim=2, keepdim=True)
50
+ updated_prob_vecs = anchor_prob.permute(0,2,1) * mask_vecs + (1-mask_vecs) * prob_vecs
51
+ color_probs = updated_prob_vecs.view(N,313,H,W)
52
+ return color_probs, anchor_adj_matrix
53
+
54
+ def _sample_anchor_colors(self, pred_prob, hint_mask, T=0):
55
+ N,C,H,W = pred_prob.shape
56
+ topk = 10
57
+ assert T < topk
58
+ sorted_probs, batch_indexs = torch.sort(pred_prob, dim=1, descending=True)
59
+ ## (N,topk,H,W,1)
60
+ topk_probs = torch.softmax(sorted_probs[:,:topk,:,:], dim=1).unsqueeze(4)
61
+ topk_indexs = batch_indexs[:,:topk,:,:]
62
+ topk_ABs = torch.stack([self.colorLabeler.q_to_ab.index_select(0, q_i.flatten()).reshape(topk,H,W,2)
63
+ for q_i in topk_indexs])
64
+ ## (N,topk,H,W,2)
65
+ topk_ABs = topk_ABs / 110.0
66
+ ## choose the most distinctive 3 colors for each anchor
67
+ if T == 0:
68
+ sampled_ABs = topk_ABs[:,0,:,:,:]
69
+ elif T == 1:
70
+ sampled_AB0 = topk_ABs[:,[0],:,:,:]
71
+ internal_diff = torch.norm(topk_ABs-sampled_AB0, p=2, dim=4, keepdim=True)
72
+ _, batch_indexs = torch.sort(internal_diff, dim=1, descending=True)
73
+ ## (N,1,H,W,2)
74
+ selected_index = batch_indexs[:,[0],:,:,:].expand([-1,-1,-1,-1,2])
75
+ sampled_ABs = torch.gather(topk_ABs, 1, selected_index)
76
+ sampled_ABs = sampled_ABs.squeeze(1)
77
+ else:
78
+ sampled_AB0 = topk_ABs[:,[0],:,:,:]
79
+ internal_diff = torch.norm(topk_ABs-sampled_AB0, p=2, dim=4, keepdim=True)
80
+ _, batch_indexs = torch.sort(internal_diff, dim=1, descending=True)
81
+ selected_index = batch_indexs[:,[0],:,:,:].expand([-1,-1,-1,-1,2])
82
+ sampled_AB1 = torch.gather(topk_ABs, 1, selected_index)
83
+ internal_diff2 = torch.norm(topk_ABs-sampled_AB1, p=2, dim=4, keepdim=True)
84
+ _, batch_indexs = torch.sort(internal_diff+internal_diff2, dim=1, descending=True)
85
+ ## (N,1,H,W,2)
86
+ selected_index = batch_indexs[:,[T-2],:,:,:].expand([-1,-1,-1,-1,2])
87
+ sampled_ABs = torch.gather(topk_ABs, 1, selected_index)
88
+ sampled_ABs = sampled_ABs.squeeze(1)
89
+
90
+ return sampled_ABs.permute(0,3,1,2)
91
+
92
+ def __call__(self, data_tensors, n_anchors, spixel_sizes, use_sklearn_kmeans=False):
93
+ N,C,H,W = data_tensors.shape
94
+ if self.mode == 'clustering':
95
+ ## clusters map: (N,K,H,W)
96
+ cluster_mask = clusterkit.batch_kmeans_pytorch(data_tensors, n_anchors, 'euclidean', use_sklearn_kmeans)
97
+ #noises = torch.rand(N,1,H,W).to(cluster_mask.device)
98
+ perturb_factors = spixel_sizes
99
+ cluster_prob = cluster_mask + perturb_factors * 0.01
100
+ hint_mask_layers = F.one_hot(torch.argmax(cluster_prob.flatten(2), dim=-1), num_classes=H*W).float()
101
+ hint_mask = torch.sum(hint_mask_layers, dim=1, keepdim=True).view(N,1,H,W)
102
+ else:
103
+ #print('----------hello, random!')
104
+ cluster_mask = torch.zeros(N,n_anchors,H,W).to(data_tensors.device)
105
+ binary_mask = basic.get_random_mask(N, H, W, minNum=n_anchors, maxNum=n_anchors)
106
+ hint_mask = torch.from_numpy(binary_mask).to(data_tensors.device)
107
+ return hint_mask, cluster_mask
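A minimal usage sketch of AnchorAnalysis in clustering mode (shapes are illustrative assumptions; it presumes CUDA and that the utils package with its CIELAB gamut data is importable, since basic.ColorLabel defaults to device='cuda'):

import torch
from models import basic, anchor_gen

feats = torch.randn(2, 64, 16, 16).cuda()        # pooled feature tokens (N,C,H,W), assumed shape
spixel_sizes = torch.rand(2, 1, 16, 16).cuda()   # relative superpixel sizes, used as a small perturbation
anchors = anchor_gen.AnchorAnalysis(mode='clustering', colorLabeler=basic.ColorLabel())
hint_mask, cluster_mask = anchors(feats, n_anchors=8, spixel_sizes=spixel_sizes)
# hint_mask: (N,1,H,W) binary anchor locations; cluster_mask: (N,8,H,W) cluster layer masks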
models/basic.py ADDED
@@ -0,0 +1,504 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ import torch.nn.utils.spectral_norm as spectral_norm
6
+ from torch.autograd import Function
7
+ from utils import util, cielab
8
+ import cv2, math, random
9
+
10
+ def tensor2array(tensors):
11
+ arrays = tensors.detach().to("cpu").numpy()
12
+ return np.transpose(arrays, (0, 2, 3, 1))
13
+
14
+
15
+ def rgb2gray(color_batch):
16
+ #! gray = 0.299*R+0.587*G+0.114*B
17
+ gray_batch = color_batch[:, 0, ...] * 0.299 + color_batch[:, 1, ...] * 0.587 + color_batch[:, 2, ...] * 0.114
18
+ gray_batch = gray_batch.unsqueeze_(1)
19
+ return gray_batch
20
+
21
+
22
+ def getParamsAmount(model):
23
+ params = list(model.parameters())
24
+ count = 0
25
+ for var in params:
26
+ l = 1
27
+ for j in var.size():
28
+ l *= j
29
+ count += l
30
+ return count
31
+
32
+
33
+ def checkAverageGradient(model):
34
+ meanGrad, cnt = 0.0, 0
35
+ for name, parms in model.named_parameters():
36
+ if parms.requires_grad:
37
+ meanGrad += torch.mean(torch.abs(parms.grad))
38
+ cnt += 1
39
+ return meanGrad.item() / cnt
40
+
41
+
42
+ def get_random_mask(N, H, W, minNum, maxNum):
43
+ binary_maps = np.zeros((N, H*W), np.float32)
44
+ for i in range(N):
45
+ locs = random.sample(range(0, H*W), random.randint(minNum,maxNum))
46
+ binary_maps[i, locs] = 1
47
+ return binary_maps.reshape(N,1,H,W)
48
+
49
+
50
+ def io_user_control(hint_mask, spix_colors, output=True):
51
+ cache_dir = '/apdcephfs/private_richardxia'
52
+ if output:
53
+ print('--- data saving')
54
+ mask_imgs = tensor2array(hint_mask) * 2.0 - 1.0
55
+ util.save_images_from_batch(mask_imgs, cache_dir, ['mask.png'], -1)
56
+ fake_gray = torch.zeros_like(spix_colors[:,[0],:,:])
57
+ spix_labs = torch.cat((fake_gray,spix_colors), dim=1)
58
+ spix_imgs = tensor2array(spix_labs)
59
+ util.save_normLabs_from_batch(spix_imgs, cache_dir, ['color.png'], -1)
60
+ return hint_mask, spix_colors
61
+ else:
62
+ print('--- data loading')
63
+ mask_img = cv2.imread(cache_dir+'/mask.png', cv2.IMREAD_GRAYSCALE)
64
+ mask_img = np.expand_dims(mask_img, axis=2) / 255.
65
+ hint_mask = torch.from_numpy(mask_img.transpose((2, 0, 1)))
66
+ hint_mask = hint_mask.unsqueeze(0).cuda()
67
+ bgr_img = cv2.imread(cache_dir+'/color.png', cv2.IMREAD_COLOR)
68
+ rgb_img = cv2.cvtColor(bgr_img, cv2.COLOR_BGR2RGB)
69
+ rgb_img = np.array(rgb_img / 255., np.float32)
70
+ lab_img = cv2.cvtColor(rgb_img, cv2.COLOR_RGB2LAB)
71
+ lab_img = torch.from_numpy(lab_img.transpose((2, 0, 1)))
72
+ ab_chans = lab_img[1:3,:,:] / 110.
73
+ spix_colors = ab_chans.unsqueeze(0).cuda()
74
+ return hint_mask.float(), spix_colors.float()
75
+
76
+
77
+ class Quantize(Function):
78
+ @staticmethod
79
+ def forward(ctx, x):
80
+ ctx.save_for_backward(x)
81
+ y = x.round()
82
+ return y
83
+
84
+ @staticmethod
85
+ def backward(ctx, grad_output):
86
+ """
87
+ In the backward pass we receive a Tensor containing the gradient of the loss
88
+ with respect to the output, and we need to compute the gradient of the loss
89
+ with respect to the input.
90
+ """
91
+ inputX = ctx.saved_tensors
92
+ return grad_output
93
+
94
+
95
+ def mark_color_hints(input_grays, target_ABs, gate_maps, kernel_size=3, base_ABs=None):
96
+ ## to highlight the seeds with 1-pixel margin
97
+ binary_map = torch.where(gate_maps>0.7, torch.ones_like(gate_maps), torch.zeros_like(gate_maps))
98
+ center_mask = dilate_seeds(binary_map, kernel_size=kernel_size)
99
+ margin_mask = dilate_seeds(binary_map, kernel_size=kernel_size+2) - center_mask
100
+ ## drop colors
101
+ dilated_seeds = dilate_seeds(gate_maps, kernel_size=kernel_size+2)
102
+ marked_grays = torch.where(margin_mask > 1e-5, torch.ones_like(gate_maps), input_grays)
103
+ if base_ABs is None:
104
+ marked_ABs = torch.where(center_mask < 1e-5, torch.zeros_like(target_ABs), target_ABs)
105
+ else:
106
+ marked_ABs = torch.where(margin_mask > 1e-5, torch.zeros_like(base_ABs), base_ABs)
107
+ marked_ABs = torch.where(center_mask > 1e-5, target_ABs, marked_ABs)
108
+ return torch.cat((marked_grays,marked_ABs), dim=1)
109
+
110
+ def dilate_seeds(gate_maps, kernel_size=3):
111
+ N,C,H,W = gate_maps.shape
112
+ input_unf = F.unfold(gate_maps, kernel_size, padding=kernel_size//2)
113
+ #! Notice: differentiable? just like max pooling?
114
+ dilated_seeds, _ = torch.max(input_unf, dim=1, keepdim=True)
115
+ output = F.fold(dilated_seeds, output_size=(H,W), kernel_size=1)
116
+ #print('-------', input_unf.shape)
117
+ return output
118
+
119
+
120
+ class RebalanceLoss(Function):
121
+ @staticmethod
122
+ def forward(ctx, data_input, weights):
123
+ ctx.save_for_backward(weights)
124
+ return data_input.clone()
125
+
126
+ @staticmethod
127
+ def backward(ctx, grad_output):
128
+ weights, = ctx.saved_tensors
129
+ # reweigh gradient pixelwise so that rare colors get a chance to
130
+ # contribute
131
+ grad_input = grad_output * weights
132
+ # second return value is None since we are not interested in the
133
+ # gradient with respect to the weights
134
+ return grad_input, None
135
+
136
+
137
+ class GetClassWeights:
138
+ def __init__(self, cielab, lambda_=0.5, device='cuda'):
139
+ prior = torch.from_numpy(cielab.gamut.prior).cuda()
140
+ uniform = torch.zeros_like(prior)
141
+ uniform[prior > 0] = 1 / (prior > 0).sum().type_as(uniform)
142
+ self.weights = 1 / ((1 - lambda_) * prior + lambda_ * uniform)
143
+ self.weights /= torch.sum(prior * self.weights)
144
+
145
+ def __call__(self, ab_actual):
146
+ return self.weights[ab_actual.argmax(dim=1, keepdim=True)]
147
+
148
+
149
+ class ColorLabel:
150
+ def __init__(self, lambda_=0.5, device='cuda'):
151
+ self.cielab = cielab.CIELAB()
152
+ self.q_to_ab = torch.from_numpy(self.cielab.q_to_ab).to(device)
153
+ prior = torch.from_numpy(self.cielab.gamut.prior).to(device)
154
+ uniform = torch.zeros_like(prior)
155
+ uniform[prior>0] = 1 / (prior>0).sum().type_as(uniform)
156
+ self.weights = 1 / ((1-lambda_) * prior + lambda_ * uniform)
157
+ self.weights /= torch.sum(prior * self.weights)
158
+
159
+ def visualize_label(self, step=3):
160
+ height, width = 200, 313*step
161
+ label_lab = np.ones((height,width,3), np.float32)
162
+ for x in range(313):
163
+ ab = self.cielab.q_to_ab[x,:]
164
+ label_lab[:,step*x:step*(x+1),1:] = ab / 110.
165
+ label_lab[:,:,0] = np.zeros((height,width), np.float32)
166
+ return label_lab
167
+
168
+ @staticmethod
169
+ def _gauss_eval(x, mu, sigma):
170
+ norm = 1 / (2 * math.pi * sigma)
171
+ return norm * torch.exp(-torch.sum((x - mu)**2, dim=0) / (2 * sigma**2))
172
+
173
+ def get_classweights(self, batch_gt_indx):
174
+ #return self.weights[batch_gt_q.argmax(dim=1, keepdim=True)]
175
+ return self.weights[batch_gt_indx]
176
+
177
+ def encode_ab2ind(self, batch_ab, neighbours=5, sigma=5.0):
178
+ batch_ab = batch_ab * 110.
179
+ n, _, h, w = batch_ab.shape
180
+ m = n * h * w
181
+ # find nearest neighbours
182
+ ab_ = batch_ab.permute(1, 0, 2, 3).reshape(2, -1) # (2, n*h*w)
183
+ cdist = torch.cdist(self.q_to_ab, ab_.t())
184
+ nns = cdist.argsort(dim=0)[:neighbours, :]
185
+ # gaussian weighting
186
+ nn_gauss = batch_ab.new_zeros(neighbours, m)
187
+ for i in range(neighbours):
188
+ nn_gauss[i, :] = self._gauss_eval(self.q_to_ab[nns[i, :], :].t(), ab_, sigma)
189
+ nn_gauss /= nn_gauss.sum(dim=0, keepdim=True)
190
+ # expand
191
+ bins = self.cielab.gamut.EXPECTED_SIZE
192
+ q = batch_ab.new_zeros(bins, m)
193
+ q[nns, torch.arange(m).repeat(neighbours, 1)] = nn_gauss
194
+ return q.reshape(bins, n, h, w).permute(1, 0, 2, 3)
195
+
196
+ def decode_ind2ab(self, batch_q, T=0.38):
197
+ _, _, h, w = batch_q.shape
198
+ batch_q = F.softmax(batch_q, dim=1)
199
+ if T%1 == 0:
200
+ # take the T-st probable index
201
+ sorted_probs, batch_indexs = torch.sort(batch_q, dim=1, descending=True)
202
+ #print('checking [index]', batch_indexs[:,0:5,5,5])
203
+ #print('checking [probs]', sorted_probs[:,0:5,5,5])
204
+ batch_indexs = batch_indexs[:,T:T+1,:,:]
205
+ #batch_indexs = torch.where(sorted_probs[:,T:T+1,:,:] > 0.25, batch_indexs[:,T:T+1,:,:], batch_indexs[:,0:1,:,:])
206
+ ab = torch.stack([
207
+ self.q_to_ab.index_select(0, q_i.flatten()).reshape(h,w,2).permute(2,0,1)
208
+ for q_i in batch_indexs])
209
+ else:
210
+ batch_q = torch.exp(batch_q / T)
211
+ batch_q /= batch_q.sum(dim=1, keepdim=True)
212
+ a = torch.tensordot(batch_q, self.q_to_ab[:,0], dims=((1,), (0,)))
213
+ a = a.unsqueeze(dim=1)
214
+ b = torch.tensordot(batch_q, self.q_to_ab[:,1], dims=((1,), (0,)))
215
+ b = b.unsqueeze(dim=1)
216
+ ab = torch.cat((a, b), dim=1)
217
+ ab = ab / 110.
218
+ return ab.type(batch_q.dtype)
219
+
220
+
221
+ def init_spixel_grid(img_height, img_width, spixel_size=16):
222
+ # get spixel id for the final assignment
223
+ n_spixl_h = int(np.floor(img_height/spixel_size))
224
+ n_spixl_w = int(np.floor(img_width/spixel_size))
225
+ spixel_height = int(img_height / (1. * n_spixl_h))
226
+ spixel_width = int(img_width / (1. * n_spixl_w))
227
+ spix_values = np.int32(np.arange(0, n_spixl_w * n_spixl_h).reshape((n_spixl_h, n_spixl_w)))
228
+
229
+ def shift9pos(input, h_shift_unit=1, w_shift_unit=1):
230
+ # input is padded to (c, 1+height+1, 1+width+1)
231
+ input_pd = np.pad(input, ((h_shift_unit, h_shift_unit), (w_shift_unit, w_shift_unit)), mode='edge')
232
+ input_pd = np.expand_dims(input_pd, axis=0)
233
+ # assign to ...
234
+ top = input_pd[:, :-2 * h_shift_unit, w_shift_unit:-w_shift_unit]
235
+ bottom = input_pd[:, 2 * h_shift_unit:, w_shift_unit:-w_shift_unit]
236
+ left = input_pd[:, h_shift_unit:-h_shift_unit, :-2 * w_shift_unit]
237
+ right = input_pd[:, h_shift_unit:-h_shift_unit, 2 * w_shift_unit:]
238
+ center = input_pd[:,h_shift_unit:-h_shift_unit,w_shift_unit:-w_shift_unit]
239
+ bottom_right = input_pd[:, 2 * h_shift_unit:, 2 * w_shift_unit:]
240
+ bottom_left = input_pd[:, 2 * h_shift_unit:, :-2 * w_shift_unit]
241
+ top_right = input_pd[:, :-2 * h_shift_unit, 2 * w_shift_unit:]
242
+ top_left = input_pd[:, :-2 * h_shift_unit, :-2 * w_shift_unit]
243
+ shift_tensor = np.concatenate([ top_left, top, top_right,
244
+ left, center, right,
245
+ bottom_left, bottom, bottom_right], axis=0)
246
+ return shift_tensor
247
+
248
+ spix_idx_tensor_ = shift9pos(spix_values)
249
+ spix_idx_tensor = np.repeat(
250
+ np.repeat(spix_idx_tensor_, spixel_height, axis=1), spixel_width, axis=2)
251
+ spixel_id_tensor = torch.from_numpy(spix_idx_tensor).type(torch.float)
252
+
253
+ #! pixel coord feature maps
254
+ all_h_coords = np.arange(0, img_height, 1)
255
+ all_w_coords = np.arange(0, img_width, 1)
256
+ curr_pxl_coord = np.array(np.meshgrid(all_h_coords, all_w_coords, indexing='ij'))
257
+ coord_feat_tensor = np.concatenate([curr_pxl_coord[1:2, :, :], curr_pxl_coord[:1, :, :]])
258
+ coord_feat_tensor = torch.from_numpy(coord_feat_tensor).type(torch.float)
259
+
260
+ return spixel_id_tensor, coord_feat_tensor
261
+
262
+
263
+ def split_spixels(assign_map, spixel_ids):
264
+ N,C,H,W = assign_map.shape
265
+ spixel_id_map = spixel_ids.expand(N,-1,-1,-1)
266
+ assig_max,_ = torch.max(assign_map, dim=1, keepdim=True)
267
+ assignment_ = torch.where(assign_map == assig_max, torch.ones(assign_map.shape).cuda(),torch.zeros(assign_map.shape).cuda())
268
+ ## winner take all
269
+ new_spixl_map_ = spixel_id_map * assignment_
270
+ new_spixl_map = torch.sum(new_spixl_map_,dim=1,keepdim=True).type(torch.int)
271
+ return new_spixl_map
272
+
273
+
274
+ def poolfeat(input, prob, sp_h=2, sp_w=2, need_entry_prob=False):
275
+ def feat_prob_sum(feat_sum, prob_sum, shift_feat):
276
+ feat_sum += shift_feat[:, :-1, :, :]
277
+ prob_sum += shift_feat[:, -1:, :, :]
278
+ return feat_sum, prob_sum
279
+
280
+ b, _, h, w = input.shape
281
+ h_shift_unit = 1
282
+ w_shift_unit = 1
283
+ p2d = (w_shift_unit, w_shift_unit, h_shift_unit, h_shift_unit)
284
+ feat_ = torch.cat([input, torch.ones([b, 1, h, w], device=input.device)], dim=1) # b* (n+1) *h*w
285
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 0, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
286
+ send_to_top_left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, 2 * w_shift_unit:]
287
+ feat_sum = send_to_top_left[:, :-1, :, :].clone()
288
+ prob_sum = send_to_top_left[:, -1:, :, :].clone()
289
+
290
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 1, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
291
+ top = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, w_shift_unit:-w_shift_unit]
292
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, top)
293
+
294
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 2, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
295
+ top_right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, 2 * h_shift_unit:, :-2 * w_shift_unit]
296
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, top_right)
297
+
298
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 3, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
299
+ left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, 2 * w_shift_unit:]
300
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, left)
301
+
302
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 4, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
303
+ center = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, w_shift_unit:-w_shift_unit]
304
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, center)
305
+
306
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 5, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
307
+ right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, h_shift_unit:-h_shift_unit, :-2 * w_shift_unit]
308
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, right)
309
+
310
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 6, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
311
+ bottom_left = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, 2 * w_shift_unit:]
312
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom_left)
313
+
314
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 7, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
315
+ bottom = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, w_shift_unit:-w_shift_unit]
316
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom)
317
+
318
+ prob_feat = F.avg_pool2d(feat_ * prob.narrow(1, 8, 1), kernel_size=(sp_h, sp_w), stride=(sp_h, sp_w)) # b * (n+1) * h* w
319
+ bottom_right = F.pad(prob_feat, p2d, mode='constant', value=0)[:, :, :-2 * h_shift_unit, :-2 * w_shift_unit]
320
+ feat_sum, prob_sum = feat_prob_sum(feat_sum, prob_sum, bottom_right)
321
+ pooled_feat = feat_sum / (prob_sum + 1e-8)
322
+ if need_entry_prob:
323
+ return pooled_feat, prob_sum
324
+ return pooled_feat
325
+
326
+
327
+ def get_spixel_size(affinity_map, sp_h=2, sp_w=2, elem_thres=25):
328
+ N,C,H,W = affinity_map.shape
329
+ device = affinity_map.device
330
+ assign_max,_ = torch.max(affinity_map, dim=1, keepdim=True)
331
+ assign_map = torch.where(affinity_map==assign_max, torch.ones(affinity_map.shape, device=device), torch.zeros(affinity_map.shape, device=device))
332
+ ## one_map = (N,1,H,W)
333
+ _, elem_num_maps = poolfeat(torch.ones(assign_max.shape, device=device), assign_map, sp_h, sp_w, True)
334
+ #all_one_map = torch.ones(elem_num_maps.shape).cuda()
335
+ #empty_mask = torch.where(elem_num_maps < elem_thres/256, all_one_map, 1-all_one_map)
336
+ return elem_num_maps
337
+
338
+
339
+ def upfeat(input, prob, up_h=2, up_w=2):
340
+ # input b*n*H*W downsampled
341
+ # prob b*9*h*w
342
+ b, c, h, w = input.shape
343
+
344
+ h_shift = 1
345
+ w_shift = 1
346
+
347
+ p2d = (w_shift, w_shift, h_shift, h_shift)
348
+ feat_pd = F.pad(input, p2d, mode='constant', value=0)
349
+
350
+ gt_frm_top_left = F.interpolate(feat_pd[:, :, :-2 * h_shift, :-2 * w_shift], size=(h * up_h, w * up_w),mode='nearest')
351
+ feat_sum = gt_frm_top_left * prob.narrow(1,0,1)
352
+
353
+ top = F.interpolate(feat_pd[:, :, :-2 * h_shift, w_shift:-w_shift], size=(h * up_h, w * up_w), mode='nearest')
354
+ feat_sum += top * prob.narrow(1, 1, 1)
355
+
356
+ top_right = F.interpolate(feat_pd[:, :, :-2 * h_shift, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
357
+ feat_sum += top_right * prob.narrow(1,2,1)
358
+
359
+ left = F.interpolate(feat_pd[:, :, h_shift:-w_shift, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
360
+ feat_sum += left * prob.narrow(1, 3, 1)
361
+
362
+ center = F.interpolate(input, (h * up_h, w * up_w), mode='nearest')
363
+ feat_sum += center * prob.narrow(1, 4, 1)
364
+
365
+ right = F.interpolate(feat_pd[:, :, h_shift:-w_shift, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
366
+ feat_sum += right * prob.narrow(1, 5, 1)
367
+
368
+ bottom_left = F.interpolate(feat_pd[:, :, 2 * h_shift:, :-2 * w_shift], size=(h * up_h, w * up_w), mode='nearest')
369
+ feat_sum += bottom_left * prob.narrow(1, 6, 1)
370
+
371
+ bottom = F.interpolate(feat_pd[:, :, 2 * h_shift:, w_shift:-w_shift], size=(h * up_h, w * up_w), mode='nearest')
372
+ feat_sum += bottom * prob.narrow(1, 7, 1)
373
+
374
+ bottom_right = F.interpolate(feat_pd[:, :, 2 * h_shift:, 2 * w_shift:], size=(h * up_h, w * up_w), mode='nearest')
375
+ feat_sum += bottom_right * prob.narrow(1, 8, 1)
376
+
377
+ return feat_sum
378
+
379
+
380
+ def suck_and_spread(self, base_maps, seg_layers):
381
+ N,S,H,W = seg_layers.shape
382
+ base_maps = base_maps.unsqueeze(1)
383
+ seg_layers = seg_layers.unsqueeze(2)
384
+ ## (N,S,C,1,1) = (N,1,C,H,W) * (N,S,1,H,W)
385
+ mean_val_layers = (base_maps * seg_layers).sum(dim=(3,4), keepdim=True) / (1e-5 + seg_layers.sum(dim=(3,4), keepdim=True))
386
+ ## normalized to be sum one
387
+ weight_layers = seg_layers / (1e-5 + torch.sum(seg_layers, dim=1, keepdim=True))
388
+ ## (N,S,C,H,W) = (N,S,C,1,1) * (N,S,1,H,W)
389
+ recon_maps = mean_val_layers * weight_layers
390
+ return recon_maps.sum(dim=1)
391
+
392
+
393
+ #! copy from Richard Zhang [SIGGRAPH2017]
394
+ # RGB grid points maps to Lab range: L[0,100], a[-86.183,98,233], b[-107.857,94.478]
395
+ #------------------------------------------------------------------------------
396
+ def rgb2xyz(rgb): # rgb from [0,1]
397
+ # xyz_from_rgb = np.array([[0.412453, 0.357580, 0.180423],
398
+ # [0.212671, 0.715160, 0.072169],
399
+ # [0.019334, 0.119193, 0.950227]])
400
+ mask = (rgb > .04045).type(torch.FloatTensor)
401
+ if(rgb.is_cuda):
402
+ mask = mask.cuda()
403
+ rgb = (((rgb+.055)/1.055)**2.4)*mask + rgb/12.92*(1-mask)
404
+ x = .412453*rgb[:,0,:,:]+.357580*rgb[:,1,:,:]+.180423*rgb[:,2,:,:]
405
+ y = .212671*rgb[:,0,:,:]+.715160*rgb[:,1,:,:]+.072169*rgb[:,2,:,:]
406
+ z = .019334*rgb[:,0,:,:]+.119193*rgb[:,1,:,:]+.950227*rgb[:,2,:,:]
407
+ out = torch.cat((x[:,None,:,:],y[:,None,:,:],z[:,None,:,:]),dim=1)
408
+ return out
409
+
410
+ def xyz2rgb(xyz):
411
+ # array([[ 3.24048134, -1.53715152, -0.49853633],
412
+ # [-0.96925495, 1.87599 , 0.04155593],
413
+ # [ 0.05564664, -0.20404134, 1.05731107]])
414
+ r = 3.24048134*xyz[:,0,:,:]-1.53715152*xyz[:,1,:,:]-0.49853633*xyz[:,2,:,:]
415
+ g = -0.96925495*xyz[:,0,:,:]+1.87599*xyz[:,1,:,:]+.04155593*xyz[:,2,:,:]
416
+ b = .05564664*xyz[:,0,:,:]-.20404134*xyz[:,1,:,:]+1.05731107*xyz[:,2,:,:]
417
+ rgb = torch.cat((r[:,None,:,:],g[:,None,:,:],b[:,None,:,:]),dim=1)
418
+ #! sometimes reaches a small negative number, which causes NaNs
419
+ rgb = torch.max(rgb,torch.zeros_like(rgb))
420
+ mask = (rgb > .0031308).type(torch.FloatTensor)
421
+ if(rgb.is_cuda):
422
+ mask = mask.cuda()
423
+ rgb = (1.055*(rgb**(1./2.4)) - 0.055)*mask + 12.92*rgb*(1-mask)
424
+ return rgb
425
+
426
+ def xyz2lab(xyz):
427
+ # 0.95047, 1., 1.08883 # white
428
+ sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
429
+ if(xyz.is_cuda):
430
+ sc = sc.cuda()
431
+ xyz_scale = xyz/sc
432
+ mask = (xyz_scale > .008856).type(torch.FloatTensor)
433
+ if(xyz_scale.is_cuda):
434
+ mask = mask.cuda()
435
+ xyz_int = xyz_scale**(1/3.)*mask + (7.787*xyz_scale + 16./116.)*(1-mask)
436
+ L = 116.*xyz_int[:,1,:,:]-16.
437
+ a = 500.*(xyz_int[:,0,:,:]-xyz_int[:,1,:,:])
438
+ b = 200.*(xyz_int[:,1,:,:]-xyz_int[:,2,:,:])
439
+ out = torch.cat((L[:,None,:,:],a[:,None,:,:],b[:,None,:,:]),dim=1)
440
+ return out
441
+
442
+ def lab2xyz(lab):
443
+ y_int = (lab[:,0,:,:]+16.)/116.
444
+ x_int = (lab[:,1,:,:]/500.) + y_int
445
+ z_int = y_int - (lab[:,2,:,:]/200.)
446
+ if(z_int.is_cuda):
447
+ z_int = torch.max(torch.Tensor((0,)).cuda(), z_int)
448
+ else:
449
+ z_int = torch.max(torch.Tensor((0,)), z_int)
450
+ out = torch.cat((x_int[:,None,:,:],y_int[:,None,:,:],z_int[:,None,:,:]),dim=1)
451
+ mask = (out > .2068966).type(torch.FloatTensor)
452
+ if(out.is_cuda):
453
+ mask = mask.cuda()
454
+ out = (out**3.)*mask + (out - 16./116.)/7.787*(1-mask)
455
+ sc = torch.Tensor((0.95047, 1., 1.08883))[None,:,None,None]
456
+ sc = sc.to(out.device)
457
+ out = out*sc
458
+ return out
459
+
460
+ def rgb2lab(rgb, l_mean=50, l_norm=50, ab_norm=110):
461
+ #! input rgb: [0,1]
462
+ #! output lab: [-1,1]
463
+ lab = xyz2lab(rgb2xyz(rgb))
464
+ l_rs = (lab[:,[0],:,:]-l_mean) / l_norm
465
+ ab_rs = lab[:,1:,:,:] / ab_norm
466
+ out = torch.cat((l_rs,ab_rs),dim=1)
467
+ return out
468
+
469
+ def lab2rgb(lab_rs, l_mean=50, l_norm=50, ab_norm=110):
470
+ #! input lab: [-1,1]
471
+ #! output rgb: [0,1]
472
+ l_ = lab_rs[:,[0],:,:] * l_norm + l_mean
473
+ ab = lab_rs[:,1:,:,:] * ab_norm
474
+ lab = torch.cat((l_,ab), dim=1)
475
+ out = xyz2rgb(lab2xyz(lab))
476
+ return out
477
+
478
+
479
+ if __name__ == '__main__':
480
+ minL, minA, minB = 999., 999., 999.
481
+ maxL, maxA, maxB = 0., 0., 0.
482
+ for r in range(256):
483
+ print('h',r)
484
+ for g in range(256):
485
+ for b in range(256):
486
+ rgb = np.array([r,g,b], np.float32).reshape(1,1,-1) / 255.0
487
+ #lab_img = cv2.cvtColor(rgb, cv2.COLOR_RGB2LAB)
488
+ rgb = torch.from_numpy(rgb.transpose((2, 0, 1)))
489
+ rgb = rgb.reshape(1,3,1,1)
490
+ lab = rgb2lab(rgb)
491
+ lab[:,[0],:,:] = lab[:,[0],:,:] * 50 + 50
492
+ lab[:,1:,:,:] = lab[:,1:,:,:] * 110
493
+ lab = lab.squeeze()
494
+ lab_float = lab.numpy()
495
+ #print('zhang vs. cv2:', lab_float, lab_img.squeeze())
496
+ minL = min(lab_float[0], minL)
497
+ minA = min(lab_float[1], minA)
498
+ minB = min(lab_float[2], minB)
499
+ maxL = max(lab_float[0], maxL)
500
+ maxA = max(lab_float[1], maxA)
501
+ maxB = max(lab_float[2], maxB)
502
+ print('L:', minL, maxL)
503
+ print('A:', minA, maxA)
504
+ print('B:', minB, maxB)
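As a rough sanity check of the color-space helpers above, a round trip through rgb2lab/lab2rgb should approximately reproduce the input. A small sketch (assuming cv2 and the utils package are importable so models.basic loads):

import torch
from models import basic

rgb = torch.rand(1, 3, 64, 64)        # RGB batch in [0,1]
lab = basic.rgb2lab(rgb)              # L and ab channels normalized to roughly [-1,1]
rgb_back = basic.lab2rgb(lab)
print((rgb - rgb_back).abs().max())   # expected to be small, apart from clipping of out-of-gamut values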
models/clusterkit.py ADDED
@@ -0,0 +1,291 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+ from functools import partial
6
+ import numpy as np
7
+ import torch
8
+ from tqdm import tqdm
9
+ import math, random
10
+ #from sklearn.cluster import KMeans, kmeans_plusplus, MeanShift, estimate_bandwidth
11
+
12
+
13
+ def tensor_kmeans_sklearn(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
14
+ N,C,H,W = data_vecs.shape
15
+ assert N == 1, 'only supports a single image tensor'
16
+ ## (1,C,H,W) -> (HW,C)
17
+ data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
18
+ ## convert tensor to array
19
+ data_vecs_np = data_vecs.squeeze().detach().to("cpu").numpy()
20
+ km = KMeans(n_clusters=n_clusters, init='k-means++', n_init=10, max_iter=300)
21
+ pred = km.fit_predict(data_vecs_np)
22
+ cluster_ids_x = torch.from_numpy(km.labels_).to(data_vecs.device)
23
+ id_maps = cluster_ids_x.reshape(1,1,H,W).long()
24
+ if need_layer_masks:
25
+ one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
26
+ cluster_mask = one_hot_labels.permute(0,3,1,2)
27
+ return cluster_mask
28
+ return id_maps
29
+
30
+
31
+ def tensor_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', need_layer_masks=False, max_iters=20):
32
+ N,C,H,W = data_vecs.shape
33
+ assert N == 1, 'only supports a single image tensor'
34
+
35
+ ## (1,C,H,W) -> (HW,C)
36
+ data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
37
+ ## cosine | euclidean
38
+ #cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric, device=data_vecs.device)
39
+ cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,\
40
+ tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
41
+ id_maps = cluster_ids_x.reshape(1,1,H,W)
42
+ if need_layer_masks:
43
+ one_hot_labels = F.one_hot(id_maps.squeeze(1), num_classes=n_clusters).float()
44
+ cluster_mask = one_hot_labels.permute(0,3,1,2)
45
+ return cluster_mask
46
+ return id_maps
47
+
48
+
49
+ def batch_kmeans_pytorch(data_vecs, n_clusters=7, metric='euclidean', use_sklearn_kmeans=False):
50
+ N,C,H,W = data_vecs.shape
51
+ sample_list = []
52
+ for idx in range(N):
53
+ if use_sklearn_kmeans:
54
+ cluster_mask = tensor_kmeans_sklearn(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
55
+ else:
56
+ cluster_mask = tensor_kmeans_pytorch(data_vecs[idx:idx+1,:,:,:], n_clusters, metric, True)
57
+ sample_list.append(cluster_mask)
58
+ return torch.cat(sample_list, dim=0)
59
+
60
+
61
+ def get_centroid_candidates(data_vecs, n_clusters=7, metric='euclidean', max_iters=20):
62
+ N,C,H,W = data_vecs.shape
63
+ data_vecs = data_vecs.permute(0,2,3,1).view(-1,C)
64
+ cluster_ids_x, cluster_centers = kmeans(X=data_vecs, num_clusters=n_clusters, distance=metric,\
65
+ tqdm_flag=False, iter_limit=max_iters, device=data_vecs.device)
66
+ return cluster_centers
67
+
68
+
69
+ def find_distinctive_elements(data_tensor, n_clusters=7, topk=3, metric='euclidean'):
70
+ N,C,H,W = data_tensor.shape
71
+ centroid_list = []
72
+ for idx in range(N):
73
+ cluster_centers = get_centroid_candidates(data_tensor[idx:idx+1,:,:,:], n_clusters, metric)
74
+ centroid_list.append(cluster_centers)
75
+
76
+ batch_centroids = torch.stack(centroid_list, dim=0)
77
+ data_vecs = data_tensor.flatten(2)
78
+ ## distance matrix: (N,K,HW) = (N,K,C) x (N,C,HW)
79
+ AtB = torch.matmul(batch_centroids, data_vecs)
80
+ AtA = torch.matmul(batch_centroids, batch_centroids.permute(0,2,1))
81
+ BtB = torch.matmul(data_vecs.permute(0,2,1), data_vecs)
82
+ diag_A = torch.diagonal(AtA, dim1=-2, dim2=-1)
83
+ diag_B = torch.diagonal(BtB, dim1=-2, dim2=-1)
84
+ A2 = diag_A.unsqueeze(2).repeat(1,1,H*W)
85
+ B2 = diag_B.unsqueeze(1).repeat(1,n_clusters,1)
86
+ distance_map = A2 - 2*AtB + B2
87
+ values, indices = distance_map.topk(topk, dim=2, largest=False, sorted=True)
88
+ cluster_mask = torch.where(distance_map <= values[:,:,topk-1:], torch.ones_like(distance_map), torch.zeros_like(distance_map))
89
+ cluster_mask = cluster_mask.view(N,n_clusters,H,W)
90
+ return cluster_mask
91
+
92
+
93
+ ##---------------------------------------------------------------------------------
94
+ '''
95
+ resource from github: https://github.com/subhadarship/kmeans_pytorch
96
+ '''
97
+ ##---------------------------------------------------------------------------------
98
+
99
+ def initialize(X, num_clusters):
100
+ """
101
+ initialize cluster centers
102
+ :param X: (torch.tensor) matrix
103
+ :param num_clusters: (int) number of clusters
104
+ :return: (np.array) initial state
105
+ """
106
+ np.random.seed(1)
107
+ num_samples = len(X)
108
+ indices = np.random.choice(num_samples, num_clusters, replace=False)
109
+ initial_state = X[indices]
110
+ return initial_state
111
+
112
+
113
+ def kmeans(
114
+ X,
115
+ num_clusters,
116
+ distance='euclidean',
117
+ cluster_centers=[],
118
+ tol=1e-4,
119
+ tqdm_flag=True,
120
+ iter_limit=0,
121
+ device=torch.device('cpu'),
122
+ gamma_for_soft_dtw=0.001
123
+ ):
124
+ """
125
+ perform kmeans
126
+ :param X: (torch.tensor) matrix
127
+ :param num_clusters: (int) number of clusters
128
+ :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
129
+ :param tol: (float) threshold [default: 0.0001]
130
+ :param device: (torch.device) device [default: cpu]
131
+ :param tqdm_flag: Allows to turn logs on and off
132
+ :param iter_limit: hard limit for max number of iterations
133
+ :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
134
+ :return: (torch.tensor, torch.tensor) cluster ids, cluster centers
135
+ """
136
+ if tqdm_flag:
137
+ print(f'running k-means on {device}..')
138
+
139
+ if distance == 'euclidean':
140
+ pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
141
+ elif distance == 'cosine':
142
+ pairwise_distance_function = partial(pairwise_cosine, device=device)
143
+ else:
144
+ raise NotImplementedError
145
+
146
+ # convert to float
147
+ X = X.float()
148
+
149
+ # transfer to device
150
+ X = X.to(device)
151
+
152
+ # initialize
153
+ if type(cluster_centers) == list: # ToDo: make this less annoyingly weird
154
+ initial_state = initialize(X, num_clusters)
155
+ else:
156
+ if tqdm_flag:
157
+ print('resuming')
158
+ # find data point closest to the initial cluster center
159
+ initial_state = cluster_centers
160
+ dis = pairwise_distance_function(X, initial_state)
161
+ choice_points = torch.argmin(dis, dim=0)
162
+ initial_state = X[choice_points]
163
+ initial_state = initial_state.to(device)
164
+
165
+ iteration = 0
166
+ if tqdm_flag:
167
+ tqdm_meter = tqdm(desc='[running kmeans]')
168
+ while True:
169
+
170
+ dis = pairwise_distance_function(X, initial_state)
171
+
172
+ choice_cluster = torch.argmin(dis, dim=1)
173
+
174
+ initial_state_pre = initial_state.clone()
175
+
176
+ for index in range(num_clusters):
177
+ selected = torch.nonzero(choice_cluster == index).squeeze().to(device)
178
+
179
+ selected = torch.index_select(X, 0, selected)
180
+
181
+ # https://github.com/subhadarship/kmeans_pytorch/issues/16
182
+ if selected.shape[0] == 0:
183
+ selected = X[torch.randint(len(X), (1,))]
184
+
185
+ initial_state[index] = selected.mean(dim=0)
186
+
187
+ center_shift = torch.sum(
188
+ torch.sqrt(
189
+ torch.sum((initial_state - initial_state_pre) ** 2, dim=1)
190
+ ))
191
+
192
+ # increment iteration
193
+ iteration = iteration + 1
194
+
195
+ # update tqdm meter
196
+ if tqdm_flag:
197
+ tqdm_meter.set_postfix(
198
+ iteration=f'{iteration}',
199
+ center_shift=f'{center_shift ** 2:0.6f}',
200
+ tol=f'{tol:0.6f}'
201
+ )
202
+ tqdm_meter.update()
203
+ if center_shift ** 2 < tol:
204
+ break
205
+ if iter_limit != 0 and iteration >= iter_limit:
206
+ #print('hello, there!')
207
+ break
208
+
209
+ return choice_cluster.to(device), initial_state.to(device)
210
+
211
+
212
+ def kmeans_predict(
213
+ X,
214
+ cluster_centers,
215
+ distance='euclidean',
216
+ device=torch.device('cpu'),
217
+ gamma_for_soft_dtw=0.001,
218
+ tqdm_flag=True
219
+ ):
220
+ """
221
+ predict using cluster centers
222
+ :param X: (torch.tensor) matrix
223
+ :param cluster_centers: (torch.tensor) cluster centers
224
+ :param distance: (str) distance [options: 'euclidean', 'cosine'] [default: 'euclidean']
225
+ :param device: (torch.device) device [default: 'cpu']
226
+ :param gamma_for_soft_dtw: approaches to (hard) DTW as gamma -> 0
227
+ :return: (torch.tensor) cluster ids
228
+ """
229
+ if tqdm_flag:
230
+ print(f'predicting on {device}..')
231
+
232
+ if distance == 'euclidean':
233
+ pairwise_distance_function = partial(pairwise_distance, device=device, tqdm_flag=tqdm_flag)
234
+ elif distance == 'cosine':
235
+ pairwise_distance_function = partial(pairwise_cosine, device=device)
236
+ elif distance == 'soft_dtw':
237
+ sdtw = SoftDTW(use_cuda=device.type == 'cuda', gamma=gamma_for_soft_dtw)
238
+ pairwise_distance_function = partial(pairwise_soft_dtw, sdtw=sdtw, device=device)
239
+ else:
240
+ raise NotImplementedError
241
+
242
+ # convert to float
243
+ X = X.float()
244
+
245
+ # transfer to device
246
+ X = X.to(device)
247
+
248
+ dis = pairwise_distance_function(X, cluster_centers)
249
+ choice_cluster = torch.argmin(dis, dim=1)
250
+
251
+ return choice_cluster.cpu()
252
+
253
+
254
+ def pairwise_distance(data1, data2, device=torch.device('cpu'), tqdm_flag=True):
255
+ if tqdm_flag:
256
+ print(f'device is :{device}')
257
+
258
+ # transfer to device
259
+ data1, data2 = data1.to(device), data2.to(device)
260
+
261
+ # N*1*M
262
+ A = data1.unsqueeze(dim=1)
263
+
264
+ # 1*N*M
265
+ B = data2.unsqueeze(dim=0)
266
+
267
+ dis = (A - B) ** 2.0
268
+ # return N*N matrix for pairwise distance
269
+ dis = dis.sum(dim=-1).squeeze()
270
+ return dis
271
+
272
+
273
+ def pairwise_cosine(data1, data2, device=torch.device('cpu')):
274
+ # transfer to device
275
+ data1, data2 = data1.to(device), data2.to(device)
276
+
277
+ # N*1*M
278
+ A = data1.unsqueeze(dim=1)
279
+
280
+ # 1*N*M
281
+ B = data2.unsqueeze(dim=0)
282
+
283
+ # normalize the points | [0.3, 0.4] -> [0.3/sqrt(0.09 + 0.16), 0.4/sqrt(0.09 + 0.16)] = [0.3/0.5, 0.4/0.5]
284
+ A_normalized = A / A.norm(dim=-1, keepdim=True)
285
+ B_normalized = B / B.norm(dim=-1, keepdim=True)
286
+
287
+ cosine = A_normalized * B_normalized
288
+
289
+ # return N*N matrix for pairwise distance
290
+ cosine_dis = 1 - cosine.sum(dim=-1).squeeze()
291
+ return cosine_dis
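A small sketch of the k-means helpers above on random features. The sklearn path needs the commented-out `from sklearn.cluster import KMeans`; the pure-PyTorch path is used here (shapes are illustrative assumptions):

import torch
from models import clusterkit

feats = torch.randn(4, 64, 16, 16)    # (N,C,H,W) feature maps
masks = clusterkit.batch_kmeans_pytorch(feats, n_clusters=7, metric='euclidean')
print(masks.shape)                    # torch.Size([4, 7, 16, 16]), one cluster layer per channel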
models/loss.py ADDED
@@ -0,0 +1,222 @@
1
+ from __future__ import division
2
+ import os, glob, shutil, math, random, json
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torchvision
7
+ from models import basic
8
+ from utils import util
9
+
10
+ eps = 0.0000001
11
+
12
+ class SPixelLoss:
13
+ def __init__(self, psize=8, mpdist=False, gpu_no=0):
14
+ self.mpdist = mpdist
15
+ self.gpu_no = gpu_no
16
+ self.sp_size = psize
17
+
18
+ def __call__(self, data, epoch_no):
19
+ kernel_size = self.sp_size
20
+ #pos_weight = 0.003
21
+ prob = data['pred_prob']
22
+ labxy_feat = data['target_feat']
23
+ N,C,H,W = labxy_feat.shape
24
+ pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
25
+ reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
26
+ loss_map = reconstr_feat[:,:,:,:] - labxy_feat[:,:,:,:]
27
+ featLoss_idx = torch.norm(loss_map[:,:-2,:,:], p=2, dim=1).mean()
28
+ posLoss_idx = torch.norm(loss_map[:,-2:,:,:], p=2, dim=1).mean() / kernel_size
29
+ totalLoss_idx = 10*featLoss_idx + 0.003*posLoss_idx
30
+ return {'totalLoss':totalLoss_idx, 'featLoss':featLoss_idx, 'posLoss':posLoss_idx}
31
+
32
+
33
+ class AnchorColorProbLoss:
34
+ def __init__(self, hint2regress=False, enhanced=False, with_grad=False, mpdist=False, gpu_no=0):
35
+ self.mpdist = mpdist
36
+ self.gpu_no = gpu_no
37
+ self.hint2regress = hint2regress
38
+ self.enhanced = enhanced
39
+ self.with_grad = with_grad
40
+ self.rebalance_gradient = basic.RebalanceLoss.apply
41
+ self.entropy_loss = nn.CrossEntropyLoss(ignore_index=-1)
42
+ if self.enhanced:
43
+ self.VGGLoss = VGG19Loss(gpu_no=gpu_no, is_ddp=mpdist)
44
+
45
+ def _perceptual_loss(self, input_grays, input_colors, pred_colors):
46
+ input_RGBs = basic.lab2rgb(torch.cat([input_grays,input_colors], dim=1))
47
+ pred_RGBs = basic.lab2rgb(torch.cat([input_grays,pred_colors], dim=1))
48
+ ## the output of "lab2rgb" just matches the input of "VGGLoss": [0,1]
49
+ return self.VGGLoss(input_RGBs, pred_RGBs)
50
+
51
+ def _laplace_gradient(self, pred_AB, target_AB):
52
+ N,C,H,W = pred_AB.shape
53
+ kernel = torch.tensor([[1, 1, 1], [1, -8, 1], [1, 1, 1]], device=pred_AB.get_device()).float()
54
+ kernel = kernel.view(1, 1, *kernel.size()).repeat(C,1,1,1)
55
+ grad_pred = F.conv2d(pred_AB, kernel, groups=C)
56
+ grad_trg = F.conv2d(target_AB, kernel, groups=C)
57
+ return l1_loss(grad_trg, grad_pred)
58
+
59
+ def __call__(self, data, epoch_no):
60
+ N,C,H,W = data['target_label'].shape
61
+ pal_probs = self.rebalance_gradient(data['pal_prob'], data['class_weight'])
62
+ #ref_probs = data['ref_prob']
63
+ pal_probs = pal_probs.permute(0,2,3,1).contiguous().view(N*H*W, -1)
64
+ gt_labels = data['target_label'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
65
+ '''
66
+ igored_mask = data['empty_entries'].permute(0,2,3,1).contiguous().view(N*H*W, -1)
67
+ gt_labels[igored_mask] = -1
68
+ gt_labels = gt_probs.squeeze()
69
+ '''
70
+ palLoss_idx = self.entropy_loss(pal_probs, gt_labels.squeeze(dim=1))
71
+ if self.hint2regress:
72
+ ref_probs = data['ref_prob']
73
+ refLoss_idx = 50 * l2_loss(data['spix_color'], ref_probs)
74
+ else:
75
+ ref_probs = self.rebalance_gradient(data['ref_prob'], data['class_weight'])
76
+ ref_probs = ref_probs.permute(0,2,3,1).contiguous().view(N*H*W, -1)
77
+ refLoss_idx = self.entropy_loss(ref_probs, gt_labels.squeeze(dim=1))
78
+ reconLoss_idx = torch.zeros_like(palLoss_idx)
79
+ if self.enhanced:
80
+ scalar = 1.0 if self.hint2regress else 5.0
81
+ reconLoss_idx = scalar * self._perceptual_loss(data['input_gray'], data['pred_color'], data['input_color'])
82
+ if self.with_grad:
83
+ gradient_loss = self._laplace_gradient(data['pred_color'], data['input_color'])
84
+ reconLoss_idx += gradient_loss
85
+ totalLoss_idx = palLoss_idx + refLoss_idx + reconLoss_idx
86
+ #print("loss terms:", palLoss_idx.item(), refLoss_idx.item(), reconLoss_idx.item())
87
+ return {'totalLoss':totalLoss_idx, 'palLoss':palLoss_idx, 'refLoss':refLoss_idx, 'recLoss':reconLoss_idx}
88
+
89
+
90
+ def compute_affinity_pos_loss(prob_in, labxy_feat, pos_weight=0.003, kernel_size=16):
91
+ S = kernel_size
92
+ m = pos_weight
93
+ prob = prob_in.clone()
94
+ N,C,H,W = labxy_feat.shape
95
+ pooled_labxy = basic.poolfeat(labxy_feat, prob, kernel_size, kernel_size)
96
+ reconstr_feat = basic.upfeat(pooled_labxy, prob, kernel_size, kernel_size)
97
+ loss_map = reconstr_feat[:,:,:,:] - labxy_feat[:,:,:,:]
98
+ loss_feat = torch.norm(loss_map[:,:-2,:,:], p=2, dim=1).mean()
99
+ loss_pos = torch.norm(loss_map[:,-2:,:,:], p=2, dim=1).mean() * m / S
100
+ loss_affinity = loss_feat + loss_pos
101
+ return loss_affinity
102
+
103
+
104
+ def l2_loss(y_input, y_target, weight_map=None):
105
+ if weight_map is None:
106
+ return F.mse_loss(y_input, y_target)
107
+ else:
108
+ diff_map = torch.mean(torch.abs(y_input-y_target), dim=1, keepdim=True)
109
+ batch_dev = torch.sum(diff_map*diff_map*weight_map, dim=(1,2,3)) / (eps+torch.sum(weight_map, dim=(1,2,3)))
110
+ return batch_dev.mean()
111
+
112
+
113
+ def l1_loss(y_input, y_target, weight_map=None):
114
+ if weight_map is None:
115
+ return F.l1_loss(y_input, y_target)
116
+ else:
117
+ diff_map = torch.mean(torch.abs(y_input-y_target), dim=1, keepdim=True)
118
+ batch_dev = torch.sum(diff_map*weight_map, dim=(1,2,3)) / (eps+torch.sum(weight_map, dim=(1,2,3)))
119
+ return batch_dev.mean()
120
+
121
+
122
+ def masked_l1_loss(y_input, y_target, outlier_mask):
123
+ one = torch.tensor([1.0]).cuda(y_input.get_device())
124
+ weight_map = torch.where(outlier_mask, one * 0.0, one * 1.0)
125
+ return l1_loss(y_input, y_target, weight_map)
126
+
127
+
128
+ def huber_loss(y_input, y_target, delta=0.01):
129
+ mask = torch.zeros_like(y_input)
130
+ mann = torch.abs(y_input - y_target)
131
+ eucl = 0.5 * (mann**2)
132
+ mask[...] = mann < delta
133
+ loss = eucl * mask / delta + (mann - 0.5 * delta) * (1 - mask)
134
+ return torch.mean(loss)
135
+
136
+
137
+ ## Perceptual loss that uses a pretrained VGG network
138
+ class VGG19Loss(nn.Module):
139
+ def __init__(self, feat_type='liu', gpu_no=0, is_ddp=False, requires_grad=False):
140
+ super(VGG19Loss, self).__init__()
141
+ os.environ['TORCH_HOME'] = '/apdcephfs/share_1290939/richardxia/Saved/Checkpoints/VGG19'
142
+ ## data requirement: (N,C,H,W) in RGB format, [0,1] range, and resolution >= 224x224
143
+ self.mean = [0.485, 0.456, 0.406]
144
+ self.std = [0.229, 0.224, 0.225]
145
+ self.feat_type = feat_type
146
+
147
+ vgg_model = torchvision.models.vgg19(pretrained=True)
148
+ ## AssertionError: DistributedDataParallel is not needed when a module doesn't have any parameter that requires a gradient
149
+ '''
150
+ if is_ddp:
151
+ vgg_model = vgg_model.cuda(gpu_no)
152
+ vgg_model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(vgg_model)
153
+ vgg_model = torch.nn.parallel.DistributedDataParallel(vgg_model, device_ids=[gpu_no], find_unused_parameters=True)
154
+ else:
155
+ vgg_model = vgg_model.cuda(gpu_no)
156
+ '''
157
+ vgg_model = vgg_model.cuda(gpu_no)
158
+ if self.feat_type == 'liu':
159
+ ## conv1_1, conv2_1, conv3_1, conv4_1, conv5_1
160
+ self.slice1 = nn.Sequential(*list(vgg_model.features)[:2]).eval()
161
+ self.slice2 = nn.Sequential(*list(vgg_model.features)[2:7]).eval()
162
+ self.slice3 = nn.Sequential(*list(vgg_model.features)[7:12]).eval()
163
+ self.slice4 = nn.Sequential(*list(vgg_model.features)[12:21]).eval()
164
+ self.slice5 = nn.Sequential(*list(vgg_model.features)[21:30]).eval()
165
+ self.weights = [1.0/32, 1.0/16, 1.0/8, 1.0/4, 1.0]
166
+ elif self.feat_type == 'lei':
167
+ ## conv1_2, conv2_2, conv3_2, conv4_2, conv5_2
168
+ self.slice1 = nn.Sequential(*list(vgg_model.features)[:4]).eval()
169
+ self.slice2 = nn.Sequential(*list(vgg_model.features)[4:9]).eval()
170
+ self.slice3 = nn.Sequential(*list(vgg_model.features)[9:14]).eval()
171
+ self.slice4 = nn.Sequential(*list(vgg_model.features)[14:23]).eval()
172
+ self.slice5 = nn.Sequential(*list(vgg_model.features)[23:32]).eval()
173
+ self.weights = [1.0/2.6, 1.0/4.8, 1.0/3.7, 1.0/5.6, 10.0/1.5]
174
+ else:
175
+ ## maxpool after conv4_4
176
+ self.featureExactor = nn.Sequential(*list(vgg_model.features)[:28]).eval()
177
+ '''
178
+ for x in range(2):
179
+ self.slice1.add_module(str(x), pretrained_features[x])
180
+ for x in range(2, 7):
181
+ self.slice2.add_module(str(x), pretrained_features[x])
182
+ for x in range(7, 12):
183
+ self.slice3.add_module(str(x), pretrained_features[x])
184
+ for x in range(12, 21):
185
+ self.slice4.add_module(str(x), pretrained_features[x])
186
+ for x in range(21, 30):
187
+ self.slice5.add_module(str(x), pretrained_features[x])
188
+ '''
189
+ self.criterion = nn.L1Loss()
190
+
191
+ ## fixed parameters
192
+ if not requires_grad:
193
+ for param in self.parameters():
194
+ param.requires_grad = False
195
+ self.eval()
196
+ print('[*] VGG19Loss init!')
197
+
198
+ def normalize(self, tensor):
199
+ tensor = tensor.clone()
200
+ mean = torch.as_tensor(self.mean, dtype=torch.float32, device=tensor.device)
201
+ std = torch.as_tensor(self.std, dtype=torch.float32, device=tensor.device)
202
+ tensor.sub_(mean[None, :, None, None]).div_(std[None, :, None, None])
203
+ return tensor
204
+
205
+ def forward(self, x, y):
206
+ norm_x, norm_y = self.normalize(x), self.normalize(y)
207
+ ## feature extract
208
+ if self.feat_type == 'liu' or self.feat_type == 'lei':
209
+ x_relu1, y_relu1 = self.slice1(norm_x), self.slice1(norm_y)
210
+ x_relu2, y_relu2 = self.slice2(x_relu1), self.slice2(y_relu1)
211
+ x_relu3, y_relu3 = self.slice3(x_relu2), self.slice3(y_relu2)
212
+ x_relu4, y_relu4 = self.slice4(x_relu3), self.slice4(y_relu3)
213
+ x_relu5, y_relu5 = self.slice5(x_relu4), self.slice5(y_relu4)
214
+ x_vgg = [x_relu1, x_relu2, x_relu3, x_relu4, x_relu5]
215
+ y_vgg = [y_relu1, y_relu2, y_relu3, y_relu4, y_relu5]
216
+ loss = 0
217
+ for i in range(len(x_vgg)):
218
+ loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
219
+ else:
220
+ x_vgg, y_vgg = self.featureExactor(norm_x), self.featureExactor(norm_y)
221
+ loss = self.criterion(x_vgg, y_vgg.detach())
222
+ return loss
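For reference, the weighted variants of l1_loss/l2_loss above normalize by the weight map rather than by the pixel count. A tiny sketch with made-up values (assumes the repository root is on PYTHONPATH so utils and torchvision import cleanly):

import torch
from models.loss import l1_loss, huber_loss

pred   = torch.rand(2, 2, 8, 8)
target = torch.rand(2, 2, 8, 8)
weight = (torch.rand(2, 1, 8, 8) > 0.5).float()    # count only the masked pixels
print(l1_loss(pred, target, weight).item())        # weighted mean absolute error
print(huber_loss(pred, target, delta=0.01).item())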
models/model.py ADDED
@@ -0,0 +1,196 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from models.network import HourGlass2, SpixelNet, ColorProbNet
5
+ from models.transformer2d import EncoderLayer, DecoderLayer, TransformerEncoder, TransformerDecoder
6
+ from models.position_encoding import build_position_encoding
7
+ from models import basic, clusterkit, anchor_gen
8
+ from collections import OrderedDict
9
+ from utils import util, cielab
10
+
11
+
12
+ class SpixelSeg(nn.Module):
13
+ def __init__(self, inChannel=1, outChannel=9, batchNorm=True):
14
+ super(SpixelSeg, self).__init__()
15
+ self.net = SpixelNet(inChannel=inChannel, outChannel=outChannel, batchNorm=batchNorm)
16
+
17
+ def get_trainable_params(self, lr=1.0):
18
+ #print('=> [optimizer] finetune backbone with smaller lr')
19
+ params = []
20
+ for name, param in self.named_parameters():
21
+ if 'xxx' in name:
22
+ params.append({'params': param, 'lr': lr})
23
+ else:
24
+ params.append({'params': param})
25
+ return params
26
+
27
+ def forward(self, input_grays):
28
+ pred_probs = self.net(input_grays)
29
+ return pred_probs
30
+
31
+
32
+ class AnchorColorProb(nn.Module):
33
+ def __init__(self, inChannel=1, outChannel=313, sp_size=16, d_model=64, use_dense_pos=True, spix_pos=False, learning_pos=False, \
34
+ random_hint=False, hint2regress=False, enhanced=False, use_mask=False, rank=0, colorLabeler=None):
35
+ super(AnchorColorProb, self).__init__()
36
+ self.sp_size = sp_size
37
+ self.spix_pos = spix_pos
38
+ self.use_token_mask = use_mask
39
+ self.hint2regress = hint2regress
40
+ self.segnet = SpixelSeg(inChannel=1, outChannel=9, batchNorm=True)
41
+ self.repnet = ColorProbNet(inChannel=inChannel, outChannel=64)
42
+ self.enhanced = enhanced
43
+ if self.enhanced:
44
+ self.enhanceNet = HourGlass2(inChannel=64+1, outChannel=2, resNum=3, normLayer=nn.BatchNorm2d)
45
+
46
+ ## transformer architecture
47
+ self.n_vocab = 313
48
+ d_model, dim_feedforward, nhead = d_model, 4*d_model, 8
49
+ dropout, activation = 0.1, "relu"
50
+ n_enc_layers, n_dec_layers = 6, 6
51
+ enc_layer = EncoderLayer(d_model, nhead, dim_feedforward, dropout, activation, use_dense_pos)
52
+ self.wildpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
53
+ self.hintpath = TransformerEncoder(enc_layer, n_enc_layers, use_dense_pos)
54
+ if self.spix_pos:
55
+ n_pos_x, n_pos_y = 256, 256
56
+ else:
57
+ n_pos_x, n_pos_y = 256//sp_size, 16//sp_size
58
+ self.pos_enc = build_position_encoding(d_model//2, n_pos_x, n_pos_y, is_learned=False)
59
+
60
+ self.mid_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
61
+ if self.hint2regress:
62
+ self.trg_word_emb = nn.Linear(d_model+2+1, d_model, bias=False)
63
+ self.trg_word_prj = nn.Linear(d_model, 2, bias=False)
64
+ else:
65
+ self.trg_word_emb = nn.Linear(d_model+self.n_vocab+1, d_model, bias=False)
66
+ self.trg_word_prj = nn.Linear(d_model, self.n_vocab, bias=False)
67
+
68
+ self.colorLabeler = colorLabeler
69
+ anchor_mode = 'random' if random_hint else 'clustering'
70
+ self.anchorGen = anchor_gen.AnchorAnalysis(mode=anchor_mode, colorLabeler=self.colorLabeler)
71
+ self._reset_parameters()
72
+
73
+ def _reset_parameters(self):
74
+ for p in self.parameters():
75
+ if p.dim() > 1:
76
+ nn.init.xavier_uniform_(p)
77
+
78
+ def load_and_froze_weight(self, checkpt_path):
79
+ data_dict = torch.load(checkpt_path, map_location=torch.device('cpu'))
80
+ '''
81
+ for param_tensor in data_dict['state_dict']:
82
+ print(param_tensor,'\t',data_dict['state_dict'][param_tensor].size())
83
+ '''
84
+ self.segnet.load_state_dict(data_dict['state_dict'])
85
+ for name, param in self.segnet.named_parameters():
86
+ param.requires_grad = False
87
+ self.segnet.eval()
88
+
89
+ def set_train(self):
90
+ ## running mode only affects certain modules, e.g. Dropout, BN, etc.
91
+ self.repnet.train()
92
+ self.wildpath.train()
93
+ self.hintpath.train()
94
+ if self.enhanced:
95
+ self.enhanceNet.train()
96
+
97
+ def get_entry_mask(self, mask_tensor):
98
+ if mask_tensor is None:
99
+ return None
100
+ ## flatten (N,1,H,W) to (N,HW)
101
+ return mask_tensor.flatten(1)
102
+
103
+ def forward(self, input_grays, input_colors, n_anchors=8, sampled_T=0):
104
+ '''
105
+ Notice: this function is customized for inference only
106
+ '''
107
+ affinity_map = self.segnet(input_grays)
108
+ pred_feats = self.repnet(input_grays)
109
+ if self.spix_pos:
110
+ full_pos_feats = self.pos_enc(pred_feats)
111
+ proxy_feats = torch.cat([pred_feats, input_colors, full_pos_feats], dim=1)
112
+ pooled_proxy_feats, conf_sum = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
113
+ feat_tokens = pooled_proxy_feats[:,:64,:,:]
114
+ spix_colors = pooled_proxy_feats[:,64:66,:,:]
115
+ pos_feats = pooled_proxy_feats[:,66:,:,:]
116
+ else:
117
+ proxy_feats = torch.cat([pred_feats, input_colors], dim=1)
118
+ pooled_proxy_feats, conf_sum = basic.poolfeat(proxy_feats, affinity_map, self.sp_size, self.sp_size, True)
119
+ feat_tokens = pooled_proxy_feats[:,:64,:,:]
120
+ spix_colors = pooled_proxy_feats[:,64:,:,:]
121
+ pos_feats = self.pos_enc(feat_tokens)
122
+
123
+ token_labels = torch.max(self.colorLabeler.encode_ab2ind(spix_colors), dim=1, keepdim=True)[1]
124
+ spixel_sizes = basic.get_spixel_size(affinity_map, self.sp_size, self.sp_size)
125
+ all_one_map = torch.ones(spixel_sizes.shape, device=input_grays.device)
126
+ empty_entries = torch.where(spixel_sizes < 25/(self.sp_size**2), all_one_map, 1-all_one_map)
127
+ src_pad_mask = self.get_entry_mask(empty_entries) if self.use_token_mask else None
128
+ trg_pad_mask = src_pad_mask
129
+
130
+ ## parallel prob
131
+ N,C,H,W = feat_tokens.shape
132
+ ## (N,C,H,W) -> (HW,N,C)
133
+ src_pos_seq = pos_feats.flatten(2).permute(2, 0, 1)
134
+ src_seq = feat_tokens.flatten(2).permute(2, 0, 1)
135
+ ## color prob branch
136
+ enc_out, _ = self.wildpath(src_seq, src_pos_seq, src_pad_mask)
137
+ pal_logit = self.mid_word_prj(enc_out)
138
+ pal_logit = pal_logit.permute(1, 2, 0).view(N,self.n_vocab,H,W)
139
+
140
+ ## seed prob branch
141
+ ## mask(N,1,H,W): sample anchors at clustering layers
142
+ color_feat = enc_out.permute(1, 2, 0).view(N,C,H,W)
143
+ hint_mask, cluster_mask = self.anchorGen(color_feat, n_anchors, spixel_sizes, use_sklearn_kmeans=False)
144
+ pred_prob = torch.softmax(pal_logit, dim=1)
145
+ color_feat2 = src_seq.permute(1, 2, 0).view(N,C,H,W)
146
+ #pred_prob, adj_matrix = self.anchorGen._detect_correlation(color_feat, pred_prob, hint_mask, thres=0.1)
147
+ if sampled_T < 0:
148
+ ## GT anchor colors
149
+ sampled_spix_colors = spix_colors
150
+ elif sampled_T > 0:
151
+ top1_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=0)
152
+ top2_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=1)
153
+ top3_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=2)
154
+ ## duplicate meta tensors
155
+ sampled_spix_colors = torch.cat((top1_spix_colors,top2_spix_colors,top3_spix_colors), dim=0)
156
+ N = 3*N
157
+ input_grays = input_grays.expand(N,-1,-1,-1)
158
+ hint_mask = hint_mask.expand(N,-1,-1,-1)
159
+ affinity_map = affinity_map.expand(N,-1,-1,-1)
160
+ src_seq = src_seq.expand(-1, N,-1)
161
+ src_pos_seq = src_pos_seq.expand(-1, N,-1)
162
+ else:
163
+ sampled_spix_colors = self.anchorGen._sample_anchor_colors(pred_prob, hint_mask, T=sampled_T)
164
+ ## debug: controllable
165
+ if False:
166
+ hint_mask, sampled_spix_colors = basic.io_user_control(hint_mask, spix_colors, output=False)
167
+
168
+ sampled_token_labels = torch.max(self.colorLabeler.encode_ab2ind(sampled_spix_colors), dim=1, keepdim=True)[1]
169
+
170
+ ## hint based prediction
171
+ ## (N,C,H,W) -> (HW,N,C)
172
+ mask_seq = hint_mask.flatten(2).permute(2, 0, 1)
173
+ if self.hint2regress:
174
+ spix_colors_ = sampled_spix_colors
175
+ gt_seq = spix_colors_.flatten(2).permute(2, 0, 1)
176
+ hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * gt_seq, mask_seq], dim=2))
177
+ dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
178
+ else:
179
+ token_labels_ = sampled_token_labels
180
+ label_map = F.one_hot(token_labels_, num_classes=313).squeeze(1).float()
181
+ label_seq = label_map.permute(0, 3, 1, 2).flatten(2).permute(2, 0, 1)
182
+ hint_seq = self.trg_word_emb(torch.cat([src_seq, mask_seq * label_seq, mask_seq], dim=2))
183
+ dec_out, _ = self.hintpath(hint_seq, src_pos_seq, src_pad_mask)
184
+ ref_logit = self.trg_word_prj(dec_out)
185
+ Ct = 2 if self.hint2regress else self.n_vocab
186
+ ref_logit = ref_logit.permute(1, 2, 0).view(N,Ct,H,W)
187
+
188
+ ## pixelwise enhancement
189
+ pred_colors = None
190
+ if self.enhanced:
191
+ proc_feats = dec_out.permute(1, 2, 0).view(N,64,H,W)
192
+ full_feats = basic.upfeat(proc_feats, affinity_map, self.sp_size, self.sp_size)
193
+ pred_colors = self.enhanceNet(torch.cat((input_grays,full_feats), dim=1))
194
+ pred_colors = torch.tanh(pred_colors)
195
+
196
+ return pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask
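For orientation, here is a minimal inference sketch for the forward() above. It assumes a `colorizer` instance that has already been constructed and loaded from a checkpoint (the construction is not part of this diff), a 256x256 grayscale input, and empty ab channels at test time; all names and value ranges outside the forward() signature are illustrative, not part of the repository API.

```python
# Hypothetical usage sketch; `colorizer` and the input normalization are assumptions.
import torch

input_grays = torch.rand(1, 1, 256, 256)      # luminance channel
input_colors = torch.zeros(1, 2, 256, 256)    # no ab information at inference
with torch.no_grad():
    pal_logit, ref_logit, pred_colors, affinity_map, spix_colors, hint_mask = colorizer(
        input_grays, input_colors, n_anchors=8, sampled_T=0)
# pred_colors is the per-pixel ab prediction in [-1, 1] (tanh) when the enhancement head is enabled.
```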
models/network.py ADDED
@@ -0,0 +1,352 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ import torchvision
+ import torch.nn.utils.spectral_norm as spectral_norm
+ import math
+
+
+ class ConvBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum, normLayer=None):
+         super(ConvBlock, self).__init__()
+         self.inConv = nn.Sequential(
+             nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1),
+             nn.ReLU(inplace=True)
+         )
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         x = self.inConv(x)
+         x = self.conv(x)
+         return x
+
+
+ class ResidualBlock(nn.Module):
+     def __init__(self, channels, normLayer=None):
+         super(ResidualBlock, self).__init__()
+         layers = []
+         layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
+         layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
+         if not (normLayer is None):
+             layers.append(normLayer(channels))
+         layers.append(nn.ReLU(inplace=True))
+         layers.append(nn.Conv2d(channels, channels, kernel_size=3, padding=1))
+         if not (normLayer is None):
+             layers.append(normLayer(channels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         residual = self.conv(x)
+         return F.relu(x + residual, inplace=True)
+
+
+ class ResidualBlockSN(nn.Module):
+     def __init__(self, channels, normLayer=None):
+         super(ResidualBlockSN, self).__init__()
+         layers = []
+         layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
+         layers.append(nn.LeakyReLU(0.2, True))
+         layers.append(spectral_norm(nn.Conv2d(channels, channels, kernel_size=3, padding=1)))
+         if not (normLayer is None):
+             layers.append(normLayer(channels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         residual = self.conv(x)
+         return F.leaky_relu(x + residual, 2e-1, inplace=True)
+
+
+ class DownsampleBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
+         super(DownsampleBlock, self).__init__()
+         layers = []
+         layers.append(nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=2))
+         layers.append(nn.ReLU(inplace=True))
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv = nn.Sequential(*layers)
+
+     def forward(self, x):
+         return self.conv(x)
+
+
+ class UpsampleBlock(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
+         super(UpsampleBlock, self).__init__()
+         self.conv1 = nn.Conv2d(inChannels, outChannels, kernel_size=3, padding=1, stride=1)
+         self.combine = nn.Conv2d(2 * outChannels, outChannels, kernel_size=3, padding=1)
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1))
+             layers.append(nn.ReLU(inplace=True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv2 = nn.Sequential(*layers)
+
+     def forward(self, x, x0):
+         x = self.conv1(x)
+         x = F.interpolate(x, scale_factor=2, mode='nearest')
+         x = self.combine(torch.cat((x, x0), 1))
+         x = F.relu(x)
+         return self.conv2(x)
+
+
+ class UpsampleBlockSN(nn.Module):
+     def __init__(self, inChannels, outChannels, convNum=2, normLayer=None):
+         super(UpsampleBlockSN, self).__init__()
+         self.conv1 = spectral_norm(nn.Conv2d(inChannels, outChannels, kernel_size=3, stride=1, padding=1))
+         self.shortcut = spectral_norm(nn.Conv2d(outChannels, outChannels, kernel_size=3, stride=1, padding=1))
+         layers = []
+         for _ in range(convNum - 1):
+             layers.append(spectral_norm(nn.Conv2d(outChannels, outChannels, kernel_size=3, padding=1)))
+             layers.append(nn.LeakyReLU(0.2, True))
+         if not (normLayer is None):
+             layers.append(normLayer(outChannels))
+         self.conv2 = nn.Sequential(*layers)
+
+     def forward(self, x, x0):
+         x = self.conv1(x)
+         x = F.interpolate(x, scale_factor=2, mode='nearest')
+         x = x + self.shortcut(x0)
+         x = F.leaky_relu(x, 2e-1)
+         return self.conv2(x)
+
+
+ class HourGlass2(nn.Module):
+     def __init__(self, inChannel=3, outChannel=1, resNum=3, normLayer=None):
+         super(HourGlass2, self).__init__()
+         self.inConv = ConvBlock(inChannel, 64, convNum=2, normLayer=normLayer)
+         self.down1 = DownsampleBlock(64, 128, convNum=2, normLayer=normLayer)
+         self.down2 = DownsampleBlock(128, 256, convNum=2, normLayer=normLayer)
+         self.residual = nn.Sequential(*[ResidualBlock(256) for _ in range(resNum)])
+         self.up2 = UpsampleBlock(256, 128, convNum=3, normLayer=normLayer)
+         self.up1 = UpsampleBlock(128, 64, convNum=3, normLayer=normLayer)
+         self.outConv = nn.Conv2d(64, outChannel, kernel_size=3, padding=1)
+
+     def forward(self, x):
+         f1 = self.inConv(x)
+         f2 = self.down1(f1)
+         f3 = self.down2(f2)
+         r3 = self.residual(f3)
+         r2 = self.up2(r3, f2)
+         r1 = self.up1(r2, f1)
+         y = self.outConv(r1)
+         return y
+
+
+ class ColorProbNet(nn.Module):
+     def __init__(self, inChannel=1, outChannel=2, with_SA=False):
+         super(ColorProbNet, self).__init__()
+         BNFunc = nn.BatchNorm2d
+         # conv1: 256
+         conv1_2 = [spectral_norm(nn.Conv2d(inChannel, 64, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv1_2 += [spectral_norm(nn.Conv2d(64, 64, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv1_2 += [BNFunc(64, affine=True)]
+         # conv2: 128
+         conv2_3 = [spectral_norm(nn.Conv2d(64, 128, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv2_3 += [spectral_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv2_3 += [spectral_norm(nn.Conv2d(128, 128, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv2_3 += [BNFunc(128, affine=True)]
+         # conv3: 64
+         conv3_3 = [spectral_norm(nn.Conv2d(128, 256, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv3_3 += [spectral_norm(nn.Conv2d(256, 256, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv3_3 += [spectral_norm(nn.Conv2d(256, 256, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv3_3 += [BNFunc(256, affine=True)]
+         # conv4: 32
+         conv4_3 = [spectral_norm(nn.Conv2d(256, 512, 3, stride=2, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv4_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv4_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv4_3 += [BNFunc(512, affine=True)]
+         # conv5: 32
+         conv5_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv5_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv5_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv5_3 += [BNFunc(512, affine=True)]
+         # conv6: 32
+         conv6_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv6_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv6_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv6_3 += [BNFunc(512, affine=True),]
+         if with_SA:
+             conv6_3 += [Self_Attn(512)]
+         # conv7: 32
+         conv7_3 = [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv7_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv7_3 += [spectral_norm(nn.Conv2d(512, 512, 3, stride=1, padding=1)), nn.LeakyReLU(0.2, True),]
+         conv7_3 += [BNFunc(512, affine=True)]
+         # conv8: 64
+         conv8up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(512, 256, 3, stride=1, padding=1),]
+         conv3short8 = [nn.Conv2d(256, 256, 3, stride=1, padding=1),]
+         conv8_3 = [nn.ReLU(True),]
+         conv8_3 += [nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(True),]
+         conv8_3 += [nn.Conv2d(256, 256, 3, stride=1, padding=1), nn.ReLU(True),]
+         conv8_3 += [BNFunc(256, affine=True),]
+         # conv9: 128
+         conv9up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(256, 128, 3, stride=1, padding=1),]
+         conv9_2 = [nn.Conv2d(128, 128, 3, stride=1, padding=1), nn.ReLU(True),]
+         conv9_2 += [BNFunc(128, affine=True)]
+         # conv10: 64
+         conv10up = [nn.Upsample(scale_factor=2, mode='nearest'), nn.Conv2d(128, 64, 3, stride=1, padding=1),]
+         conv10_2 = [nn.ReLU(True),]
+         conv10_2 += [nn.Conv2d(64, outChannel, 3, stride=1, padding=1), nn.ReLU(True),]
+
+         self.conv1_2 = nn.Sequential(*conv1_2)
+         self.conv2_3 = nn.Sequential(*conv2_3)
+         self.conv3_3 = nn.Sequential(*conv3_3)
+         self.conv4_3 = nn.Sequential(*conv4_3)
+         self.conv5_3 = nn.Sequential(*conv5_3)
+         self.conv6_3 = nn.Sequential(*conv6_3)
+         self.conv7_3 = nn.Sequential(*conv7_3)
+         self.conv8up = nn.Sequential(*conv8up)
+         self.conv3short8 = nn.Sequential(*conv3short8)
+         self.conv8_3 = nn.Sequential(*conv8_3)
+         self.conv9up = nn.Sequential(*conv9up)
+         self.conv9_2 = nn.Sequential(*conv9_2)
+         self.conv10up = nn.Sequential(*conv10up)
+         self.conv10_2 = nn.Sequential(*conv10_2)
+         # classification output
+         #self.model_class = nn.Sequential(*[nn.Conv2d(256, 313, kernel_size=1, padding=0, stride=1),])
+
+     def forward(self, input_grays):
+         f1_2 = self.conv1_2(input_grays)
+         f2_3 = self.conv2_3(f1_2)
+         f3_3 = self.conv3_3(f2_3)
+         f4_3 = self.conv4_3(f3_3)
+         f5_3 = self.conv5_3(f4_3)
+         f6_3 = self.conv6_3(f5_3)
+         f7_3 = self.conv7_3(f6_3)
+         f8_up = self.conv8up(f7_3) + self.conv3short8(f3_3)
+         f8_3 = self.conv8_3(f8_up)
+         f9_up = self.conv9up(f8_3)
+         f9_2 = self.conv9_2(f9_up)
+         f10_up = self.conv10up(f9_2)
+         f10_2 = self.conv10_2(f10_up)
+         out_feats = f10_2
+         #out_probs = self.model_class(f8_3)
+         return out_feats
+
+
+
+ def conv(batchNorm, in_planes, out_planes, kernel_size=3, stride=1):
+     if batchNorm:
+         return nn.Sequential(
+             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=False),
+             nn.BatchNorm2d(out_planes),
+             nn.LeakyReLU(0.1)
+         )
+     else:
+         return nn.Sequential(
+             nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=(kernel_size-1)//2, bias=True),
+             nn.LeakyReLU(0.1)
+         )
+
+
+ def deconv(in_planes, out_planes):
+     return nn.Sequential(
+         nn.ConvTranspose2d(in_planes, out_planes, kernel_size=4, stride=2, padding=1, bias=True),
+         nn.LeakyReLU(0.1)
+     )
+
+ class SpixelNet(nn.Module):
+     def __init__(self, inChannel=3, outChannel=9, batchNorm=True):
+         super(SpixelNet,self).__init__()
+         self.batchNorm = batchNorm
+         self.conv0a = conv(self.batchNorm, inChannel, 16, kernel_size=3)
+         self.conv0b = conv(self.batchNorm, 16, 16, kernel_size=3)
+         self.conv1a = conv(self.batchNorm, 16, 32, kernel_size=3, stride=2)
+         self.conv1b = conv(self.batchNorm, 32, 32, kernel_size=3)
+         self.conv2a = conv(self.batchNorm, 32, 64, kernel_size=3, stride=2)
+         self.conv2b = conv(self.batchNorm, 64, 64, kernel_size=3)
+         self.conv3a = conv(self.batchNorm, 64, 128, kernel_size=3, stride=2)
+         self.conv3b = conv(self.batchNorm, 128, 128, kernel_size=3)
+         self.conv4a = conv(self.batchNorm, 128, 256, kernel_size=3, stride=2)
+         self.conv4b = conv(self.batchNorm, 256, 256, kernel_size=3)
+         self.deconv3 = deconv(256, 128)
+         self.conv3_1 = conv(self.batchNorm, 256, 128)
+         self.deconv2 = deconv(128, 64)
+         self.conv2_1 = conv(self.batchNorm, 128, 64)
+         self.deconv1 = deconv(64, 32)
+         self.conv1_1 = conv(self.batchNorm, 64, 32)
+         self.deconv0 = deconv(32, 16)
+         self.conv0_1 = conv(self.batchNorm, 32, 16)
+         self.pred_mask0 = nn.Conv2d(16, outChannel, kernel_size=3, stride=1, padding=1, bias=True)
+         self.softmax = nn.Softmax(1)
+         for m in self.modules():
+             if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+                 init.kaiming_normal_(m.weight, 0.1)
+                 if m.bias is not None:
+                     init.constant_(m.bias, 0)
+             elif isinstance(m, nn.BatchNorm2d):
+                 init.constant_(m.weight, 1)
+                 init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         out1 = self.conv0b(self.conv0a(x)) #5*5
+         out2 = self.conv1b(self.conv1a(out1)) #11*11
+         out3 = self.conv2b(self.conv2a(out2)) #23*23
+         out4 = self.conv3b(self.conv3a(out3)) #47*47
+         out5 = self.conv4b(self.conv4a(out4)) #95*95
+         out_deconv3 = self.deconv3(out5)
+         concat3 = torch.cat((out4, out_deconv3), 1)
+         out_conv3_1 = self.conv3_1(concat3)
+         out_deconv2 = self.deconv2(out_conv3_1)
+         concat2 = torch.cat((out3, out_deconv2), 1)
+         out_conv2_1 = self.conv2_1(concat2)
+         out_deconv1 = self.deconv1(out_conv2_1)
+         concat1 = torch.cat((out2, out_deconv1), 1)
+         out_conv1_1 = self.conv1_1(concat1)
+         out_deconv0 = self.deconv0(out_conv1_1)
+         concat0 = torch.cat((out1, out_deconv0), 1)
+         out_conv0_1 = self.conv0_1(concat0)
+         mask0 = self.pred_mask0(out_conv0_1)
+         prob0 = self.softmax(mask0)
+         return prob0
+
+
+
+ ## VGG architecture, used for the perceptual loss with a pretrained VGG network
+ class VGG19(torch.nn.Module):
+     def __init__(self, requires_grad=False, local_pretrained_path='checkpoints/vgg19.pth'):
+         super().__init__()
+         #vgg_pretrained_features = torchvision.models.vgg19(pretrained=True).features
+         model = torchvision.models.vgg19()
+         model.load_state_dict(torch.load(local_pretrained_path))
+         vgg_pretrained_features = model.features
+
+         self.slice1 = torch.nn.Sequential()
+         self.slice2 = torch.nn.Sequential()
+         self.slice3 = torch.nn.Sequential()
+         self.slice4 = torch.nn.Sequential()
+         self.slice5 = torch.nn.Sequential()
+         for x in range(2):
+             self.slice1.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(2, 7):
+             self.slice2.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(7, 12):
+             self.slice3.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(12, 21):
+             self.slice4.add_module(str(x), vgg_pretrained_features[x])
+         for x in range(21, 30):
+             self.slice5.add_module(str(x), vgg_pretrained_features[x])
+         if not requires_grad:
+             for param in self.parameters():
+                 param.requires_grad = False
+
+     def forward(self, X):
+         h_relu1 = self.slice1(X)
+         h_relu2 = self.slice2(h_relu1)
+         h_relu3 = self.slice3(h_relu2)
+         h_relu4 = self.slice4(h_relu3)
+         h_relu5 = self.slice5(h_relu4)
+         out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
+         return out
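As a quick sanity check on the two backbones used by the colorization model, the sketch below pushes random inputs through HourGlass2 (the pixelwise enhancement net) and SpixelNet (the superpixel affinity predictor) and prints the output shapes; the batch and image sizes are illustrative.

```python
# Shape smoke test for HourGlass2 and SpixelNet (illustrative sizes only).
import torch
from models.network import HourGlass2, SpixelNet

gray = torch.randn(2, 1, 256, 256)
enhancer = HourGlass2(inChannel=1, outChannel=2, resNum=3)
print(enhancer(gray).shape)            # torch.Size([2, 2, 256, 256])

rgb = torch.randn(2, 3, 256, 256)
spixel_net = SpixelNet(inChannel=3, outChannel=9)
print(spixel_net(rgb).shape)           # torch.Size([2, 9, 256, 256]), softmax over the 9 neighbor assignments
```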
models/position_encoding.py ADDED
@@ -0,0 +1,86 @@
+ # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
+ """
+ Various positional encodings for the transformer.
+ """
+ import math
+ import torch
+ from torch import nn
+
+
+ class PositionEmbeddingSine(nn.Module):
+     """
+     This is a more standard version of the position embedding, very similar to the one
+     used by the Attention is all you need paper, generalized to work on images.
+     """
+     def __init__(self, num_pos_feats=64, temperature=10000, normalize=False, scale=None):
+         super().__init__()
+         self.num_pos_feats = num_pos_feats
+         self.temperature = temperature
+         self.normalize = normalize
+         if scale is not None and normalize is False:
+             raise ValueError("normalize should be True if scale is passed")
+         if scale is None:
+             scale = 2 * math.pi
+         self.scale = scale
+
+     def forward(self, token_tensors):
+         ## input: (B,C,H,W)
+         x = token_tensors
+         h, w = x.shape[-2:]
+         identity_map = torch.ones((h,w), device=x.device)
+         y_embed = identity_map.cumsum(0, dtype=torch.float32)
+         x_embed = identity_map.cumsum(1, dtype=torch.float32)
+         if self.normalize:
+             eps = 1e-6
+             y_embed = y_embed / (y_embed[-1:, :] + eps) * self.scale
+             x_embed = x_embed / (x_embed[:, -1:] + eps) * self.scale
+
+         dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+         dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
+
+         pos_x = x_embed[:, :, None] / dim_t
+         pos_y = y_embed[:, :, None] / dim_t
+         pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
+         pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
+         pos = torch.cat((pos_y, pos_x), dim=2).permute(2, 0, 1)
+         batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return batch_pos
+
+
+ class PositionEmbeddingLearned(nn.Module):
+     """
+     Absolute pos embedding, learned.
+     """
+     def __init__(self, n_pos_x=16, n_pos_y=16, num_pos_feats=64):
+         super().__init__()
+         self.row_embed = nn.Embedding(n_pos_y, num_pos_feats)
+         self.col_embed = nn.Embedding(n_pos_x, num_pos_feats)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         nn.init.uniform_(self.row_embed.weight)
+         nn.init.uniform_(self.col_embed.weight)
+
+     def forward(self, token_tensors):
+         ## input: (B,C,H,W)
+         x = token_tensors
+         h, w = x.shape[-2:]
+         i = torch.arange(w, device=x.device)
+         j = torch.arange(h, device=x.device)
+         x_emb = self.col_embed(i)
+         y_emb = self.row_embed(j)
+         pos = torch.cat([
+             x_emb.unsqueeze(0).repeat(h, 1, 1),
+             y_emb.unsqueeze(1).repeat(1, w, 1),
+         ], dim=-1).permute(2, 0, 1)
+         batch_pos = pos.unsqueeze(0).repeat(x.shape[0], 1, 1, 1)
+         return batch_pos
+
+
+ def build_position_encoding(num_pos_feats=64, n_pos_x=16, n_pos_y=16, is_learned=False):
+     if is_learned:
+         position_embedding = PositionEmbeddingLearned(n_pos_x, n_pos_y, num_pos_feats)
+     else:
+         position_embedding = PositionEmbeddingSine(num_pos_feats, normalize=True)
+
+     return position_embedding
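build_position_encoding returns either the fixed sine encoding or the learned embedding; both take a (B,C,H,W) token grid and return a position map whose channel count is 2*num_pos_feats, independent of the input's channel dimension. A short, illustrative shape check:

```python
# Positional encoding shape check (sizes are illustrative).
import torch
from models.position_encoding import build_position_encoding

pos_enc = build_position_encoding(num_pos_feats=32, is_learned=False)
tokens = torch.randn(4, 64, 16, 16)   # (B, C, H, W) superpixel token grid
pos = pos_enc(tokens)
print(pos.shape)                       # torch.Size([4, 64, 16, 16]) -> 2 * num_pos_feats channels
```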
models/transformer2d.py ADDED
@@ -0,0 +1,229 @@
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+ import copy, math
+ from models.position_encoding import build_position_encoding
+
+
+ class TransformerEncoder(nn.Module):
+
+     def __init__(self, enc_layer, num_layers, use_dense_pos=False):
+         super().__init__()
+         self.layers = nn.ModuleList([copy.deepcopy(enc_layer) for i in range(num_layers)])
+         self.num_layers = num_layers
+         self.use_dense_pos = use_dense_pos
+
+     def forward(self, src, pos, padding_mask=None):
+         if self.use_dense_pos:
+             ## pos encoding at each MH-Attention block (q,k)
+             output, pos_enc = src, pos
+             for layer in self.layers:
+                 output, att_map = layer(output, pos_enc, padding_mask)
+         else:
+             ## pos encoding at input only (q,k,v)
+             output, pos_enc = src + pos, None
+             for layer in self.layers:
+                 output, att_map = layer(output, pos_enc, padding_mask)
+         return output, att_map
+
+
+ class EncoderLayer(nn.Module):
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                  use_dense_pos=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+
+     def with_pos_embed(self, tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward(self, src, pos, padding_mask):
+         q = k = self.with_pos_embed(src, pos)
+         src2, attn = self.self_attn(q, k, value=src, key_padding_mask=padding_mask)
+         src = src + self.dropout1(src2)
+         src = self.norm1(src)
+         src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
+         src = src + self.dropout2(src2)
+         src = self.norm2(src)
+         return src, attn
+
+
+ class TransformerDecoder(nn.Module):
+
+     def __init__(self, dec_layer, num_layers, use_dense_pos=False, return_intermediate=False):
+         super().__init__()
+         self.layers = nn.ModuleList([copy.deepcopy(dec_layer) for i in range(num_layers)])
+         self.num_layers = num_layers
+         self.use_dense_pos = use_dense_pos
+         self.return_intermediate = return_intermediate
+
+     def forward(self, tgt, tgt_pos, memory, memory_pos,
+                 tgt_padding_mask, src_padding_mask, tgt_attn_mask=None):
+         intermediate = []
+         if self.use_dense_pos:
+             ## pos encoding at each MH-Attention block (q,k)
+             output = tgt
+             tgt_pos_enc, memory_pos_enc = tgt_pos, memory_pos
+             for layer in self.layers:
+                 output, att_map = layer(output, tgt_pos_enc, memory, memory_pos_enc,
+                                         tgt_padding_mask, src_padding_mask, tgt_attn_mask)
+                 if self.return_intermediate:
+                     intermediate.append(output)
+         else:
+             ## pos encoding at input only (q,k,v)
+             output = tgt + tgt_pos
+             tgt_pos_enc, memory_pos_enc = None, None
+             for layer in self.layers:
+                 output, att_map = layer(output, tgt_pos_enc, memory, memory_pos_enc,
+                                         tgt_padding_mask, src_padding_mask, tgt_attn_mask)
+                 if self.return_intermediate:
+                     intermediate.append(output)
+
+         if self.return_intermediate:
+             return torch.stack(intermediate)
+         return output, att_map
+
+
+ class DecoderLayer(nn.Module):
+
+     def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation="relu",
+                  use_dense_pos=False):
+         super().__init__()
+         self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         self.corr_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
+         # Implementation of Feedforward model
+         self.linear1 = nn.Linear(d_model, dim_feedforward)
+         self.dropout = nn.Dropout(dropout)
+         self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+         self.norm1 = nn.LayerNorm(d_model)
+         self.norm2 = nn.LayerNorm(d_model)
+         self.norm3 = nn.LayerNorm(d_model)
+         self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         self.activation = _get_activation_fn(activation)
+
+     def with_pos_embed(self, tensor, pos):
+         return tensor if pos is None else tensor + pos
+
+     def forward(self, tgt, tgt_pos, memory, memory_pos,
+                 tgt_padding_mask, memory_padding_mask, tgt_attn_mask):
+         q = k = self.with_pos_embed(tgt, tgt_pos)
+         tgt2, attn = self.self_attn(q, k, value=tgt, key_padding_mask=tgt_padding_mask,
+                                     attn_mask=tgt_attn_mask)
+         tgt = tgt + self.dropout1(tgt2)
+         tgt = self.norm1(tgt)
+         tgt2, attn = self.corr_attn(query=self.with_pos_embed(tgt, tgt_pos),
+                                     key=self.with_pos_embed(memory, memory_pos),
+                                     value=memory, key_padding_mask=memory_padding_mask)
+         tgt = tgt + self.dropout2(tgt2)
+         tgt = self.norm2(tgt)
+         tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
+         tgt = tgt + self.dropout3(tgt2)
+         tgt = self.norm3(tgt)
+         return tgt, attn
+
+
+ def _get_activation_fn(activation):
+     """Return an activation function given a string"""
+     if activation == "relu":
+         return F.relu
+     if activation == "gelu":
+         return F.gelu
+     if activation == "glu":
+         return F.glu
+     raise RuntimeError(F"activation should be relu/gelu, not {activation}.")
+
+
+ #-----------------------------------------------------------------------------------
+ '''
+ copied from the implementation of "attention-is-all-you-need-pytorch-master" by Yu-Hsiang Huang
+ '''
+
+ class MultiHeadAttention(nn.Module):
+     ''' Multi-Head Attention module '''
+
+     def __init__(self, n_head, d_model, d_k, d_v, dropout=0.1):
+         super().__init__()
+
+         self.n_head = n_head
+         self.d_k = d_k
+         self.d_v = d_v
+
+         self.w_qs = nn.Linear(d_model, n_head * d_k, bias=False)
+         self.w_ks = nn.Linear(d_model, n_head * d_k, bias=False)
+         self.w_vs = nn.Linear(d_model, n_head * d_v, bias=False)
+         self.fc = nn.Linear(n_head * d_v, d_model, bias=False)
+
+         self.attention = ScaledDotProductAttention(temperature=d_k ** 0.5)
+
+         self.dropout = nn.Dropout(dropout)
+         self.layer_norm = nn.LayerNorm(d_model, eps=1e-6)
+
+
+     def forward(self, q, k, v, mask=None):
+
+         d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+         sz_b, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+         residual = q
+
+         # Pass through the pre-attention projection: b x lq x (n*dv)
+         # Separate different heads: b x lq x n x dv
+         q = self.w_qs(q).view(sz_b, len_q, n_head, d_k)
+         k = self.w_ks(k).view(sz_b, len_k, n_head, d_k)
+         v = self.w_vs(v).view(sz_b, len_v, n_head, d_v)
+
+         # Transpose for attention dot product: b x n x lq x dv
+         q, k, v = q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2)
+
+         if mask is not None:
+             mask = mask.unsqueeze(1)   # For head axis broadcasting.
+
+         q, attn = self.attention(q, k, v, mask=mask)
+
+         # Transpose to move the head dimension back: b x lq x n x dv
+         # Combine the last two dimensions to concatenate all the heads together: b x lq x (n*dv)
+         q = q.transpose(1, 2).contiguous().view(sz_b, len_q, -1)
+         q = self.dropout(self.fc(q))
+         q += residual
+
+         q = self.layer_norm(q)
+
+         return q, attn
+
+
+
+ class ScaledDotProductAttention(nn.Module):
+     ''' Scaled Dot-Product Attention '''
+
+     def __init__(self, temperature, attn_dropout=0.1):
+         super().__init__()
+         self.temperature = temperature
+         self.dropout = nn.Dropout(attn_dropout)
+
+     def forward(self, q, k, v, mask=None):
+
+         attn = torch.matmul(q / self.temperature, k.transpose(2, 3))
+
+         if mask is not None:
+             attn = attn.masked_fill(mask == 0, -1e9)
+
+         attn = self.dropout(F.softmax(attn, dim=-1))
+         output = torch.matmul(attn, v)
+
+         return output, attn
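To close, a small smoke test wiring EncoderLayer, TransformerEncoder, and the sine position encoding together, following the (HW, N, C) sequence convention used in the colorization model's forward(); the hyperparameters below are arbitrary and only meant to exercise the shapes.

```python
# Encoder stack smoke test (illustrative hyperparameters, not the repository's training config).
import torch
from models.transformer2d import EncoderLayer, TransformerEncoder
from models.position_encoding import build_position_encoding

d_model, nhead = 64, 8
layer = EncoderLayer(d_model, nhead, dim_feedforward=4 * d_model, dropout=0.1,
                     activation="relu", use_dense_pos=False)
encoder = TransformerEncoder(layer, num_layers=2, use_dense_pos=False)

# a 16x16 token grid with batch size 2, flattened to (HW, N, C)
feat_grid = torch.randn(2, d_model, 16, 16)
pos_grid = build_position_encoding(d_model // 2, 16, 16, is_learned=False)(feat_grid)
src_seq = feat_grid.flatten(2).permute(2, 0, 1)
pos_seq = pos_grid.flatten(2).permute(2, 0, 1)

out, attn = encoder(src_seq, pos_seq, padding_mask=None)
print(out.shape)    # torch.Size([256, 2, 64])
print(attn.shape)   # torch.Size([2, 256, 256]) attention map from the last layer
```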