import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

try:
    # Private scikit-image helper; guarded so the networks and graph utilities
    # in this module stay importable when it is missing or has moved.
    from skimage.segmentation._slic import _enforce_label_connectivity_cython
except ImportError:  # pragma: no cover - depends on the local environment
    _enforce_label_connectivity_cython = None


def initWave(nPeriodic):
    """Build the initial wave coefficients as a ``(1, 2 * nPeriodic, 1, 1)`` tensor.

    For every ``i`` in ``range(nPeriodic // 4 + 1)`` the magnitude
    ``v = 0.5 + i / (nPeriodic // 4)`` contributes the eight values
    ``[0, v, v, 0, 0, -v, v, 0]``.  The list is truncated to
    ``2 * nPeriodic`` entries and scaled by pi.

    Args:
        nPeriodic: number of periodic components (non-negative int).

    Returns:
        A float32 ``torch.Tensor`` of shape ``(1, 2 * nPeriodic, 1, 1)``.
    """
    quarter = nPeriodic // 4
    buf = []
    for i in range(quarter + 1):
        # The epsilon keeps the division finite when nPeriodic < 4.
        v = 0.5 + i / float(quarter + 1e-10)
        buf += [0, v, v, 0]
        buf += [0, -v, v, 0]  # cover the other quadrants as well
    buf = buf[:2 * nPeriodic]
    awave = np.array(buf, dtype=np.float32) * np.pi
    # Add trailing H, W axes and a leading batch axis.
    return torch.FloatTensor(awave).unsqueeze(-1).unsqueeze(-1).unsqueeze(0)


class SPADEGenerator(nn.Module):
    """Decoder made of SPADE residual blocks with four nearest-neighbour 2x upsamplings.

    Args:
        hidden_dim: channel count of the input feature map.  The base width
            is ``hidden_dim // 16`` and the first block expects
            ``16 * (hidden_dim // 16)`` channels, so ``hidden_dim`` should be
            a multiple of 16.
        label_nc: channel count of the conditioning map handed to every
            SPADE layer (defaults to 256, the previously hard-coded value).
    """

    def __init__(self, hidden_dim, label_nc=256):
        super().__init__()
        nf = hidden_dim // 16

        # Attribute names and construction order are kept stable so that
        # existing checkpoints and seeded initialisation still match.
        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf, label_nc)

        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf, label_nc)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf, label_nc)

        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf, label_nc)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf, label_nc)
        self.up_2 = SPADEResnetBlock(4 * nf, nf, label_nc)

        final_nc = nf
        self.conv_img = nn.Conv2d(final_nc, 3, 3, padding=1)

        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x, input):
        """Decode ``x`` into a 3-channel map, conditioning every block on ``input``.

        Args:
            x: feature map of shape ``(B, 16 * nf, h, w)``.
            input: conditioning map of shape ``(B, label_nc, H, W)``; each
                SPADE layer resizes it to the resolution it needs.  (The
                name shadows the builtin but is kept for keyword callers.)

        Returns:
            Tensor of shape ``(B, 3, 16 * h, 16 * w)``; this is the raw
            convolution output, no final activation is applied.
        """
        seg = input

        x = self.head_0(x, seg)

        x = self.up(x)
        x = self.G_middle_0(x, seg)
        x = self.G_middle_1(x, seg)

        x = self.up(x)
        x = self.up_0(x, seg)
        x = self.up(x)
        x = self.up_1(x, seg)
        x = self.up(x)
        x = self.up_2(x, seg)

        return self.conv_img(F.leaky_relu(x, 2e-1))


class SPADE(nn.Module):
    """Spatially-adaptive normalisation layer.

    Instance-normalises the input without learned affine parameters, then
    applies a per-pixel scale and bias predicted from a conditioning map.

    Args:
        norm_nc: channel count of the activations being normalised.
        label_nc: channel count of the conditioning map.
        nhidden: width of the intermediate embedding (defaults to 128, the
            previously hard-coded value).
    """

    def __init__(self, norm_nc, label_nc, nhidden=128):
        super().__init__()

        ks = 3
        pw = ks // 2  # "same" padding for the 3x3 convolutions

        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)

        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, nhidden, kernel_size=ks, padding=pw),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)
        self.mlp_beta = nn.Conv2d(nhidden, norm_nc, kernel_size=ks, padding=pw)

    def forward(self, x, segmap):
        """Return ``norm(x) * (1 + gamma(segmap)) + beta(segmap)``.

        Args:
            x: activations of shape ``(B, norm_nc, h, w)``.
            segmap: conditioning map of shape ``(B, label_nc, H, W)``.

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Part 1. generate parameter-free normalized activations
        normalized = self.param_free_norm(x)

        # Part 2. produce scaling and bias conditioned on the map, which is
        # resized bilinearly to the spatial size of x.
        segmap = F.interpolate(
            segmap, size=x.size()[2:], mode='bilinear', align_corners=False
        )
        actv = self.mlp_shared(segmap)
        gamma = self.mlp_gamma(actv)
        beta = self.mlp_beta(actv)

        # apply scale and bias
        return normalized * (1 + gamma) + beta


class SPADEResnetBlock(nn.Module):
    """Residual block whose normalisation layers are SPADE layers.

    Besides the features, the block takes ``seg``, the conditioning map that
    drives every SPADE layer.

    Args:
        fin: input channel count.
        fout: output channel count.  When it differs from ``fin`` the skip
            path goes through a SPADE layer and a bias-free 1x1 convolution.
        label_nc: channel count of the conditioning map (defaults to 256,
            the previously hard-coded value).
    """

    def __init__(self, fin, fout, label_nc=256):
        super().__init__()
        # Attributes
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)

        # create conv layers
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)

        # define normalization layers
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    def forward(self, x, seg):
        """Return ``shortcut(x) + conv_1(act(norm_1(conv_0(act(norm_0(x))))))``."""
        x_s = self.shortcut(x, seg)

        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))

        return x_s + dx

    def shortcut(self, x, seg):
        """Skip path: identity, or SPADE + 1x1 conv when channel counts differ."""
        if self.learned_shortcut:
            return self.conv_s(self.norm_s(x, seg))
        return x

    def actvn(self, x):
        """Leaky ReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)


def get_edges(sp_label, sp_num):
    """Build the superpixel adjacency graph of every label map in a batch.

    Two superpixels are neighbours when any pair of pixels carrying their
    labels touch vertically, horizontally or along either diagonal.

    Args:
        sp_label: label maps of shape ``(B, 1, H, W)``; values are cast to
            ``long`` and must lie in ``[0, sp_num)``.
        sp_num: number of superpixels, i.e. the side of the adjacency matrix.

    Returns:
        ``(edge_indices, adjacency)`` where ``edge_indices`` is a list of
        ``B`` long tensors of shape ``(2, E_b)`` holding the non-zero
        coordinates of each matrix, and ``adjacency`` is a float tensor of
        shape ``(B, sp_num, sp_num)`` with ``N[b, i, j] == 1`` iff
        superpixels ``i`` and ``j`` are neighbours.  Self-loops are not
        included.
    """
    batch = sp_label.shape[0]
    # Initialise with torch.ones instead to include self-loops in the graph.
    adjacency = torch.zeros(batch, sp_num, sp_num, device=sp_label.device)
    edge_indices = []

    for b in range(batch):
        # Index the channel explicitly rather than squeezing, so maps whose
        # height or width is 2 keep their two spatial axes.
        labels = sp_label[b, 0]
        neighbour_pairs = (
            (labels[:-1, :], labels[1:, :]),      # top / bottom
            (labels[:, :-1], labels[:, 1:]),      # left / right
            (labels[:-1, :-1], labels[1:, 1:]),   # top-left / bottom-right
            (labels[:-1, 1:], labels[1:, :-1]),   # top-right / bottom-left
        )
        adj = adjacency[b]
        for first, second in neighbour_pairs:
            boundary = first != second
            sp1 = first[boundary].long()
            sp2 = second[boundary].long()
            # Keep the matrix symmetric.
            adj[sp1, sp2] = 1
            adj[sp2, sp1] = 1

        edge_indices.append(torch.stack(torch.nonzero(adj, as_tuple=True)))

    return edge_indices, adjacency


def enforce_connectivity(segs, H, W, sp_num=196, min_size=None, max_size=None):
    """Run scikit-image's label-connectivity enforcement on a batch of maps.

    Args:
        segs: batch of label maps; each item is squeezed to 2-D before being
            handed to scikit-image.
            NOTE(review): the Cython helper presumably requires an integer
            (``int64``) label array - confirm against the caller.
        H: map height, used only to derive the default size thresholds.
        W: map width, used only to derive the default size thresholds.
        sp_num: nominal number of superpixels; ``H * W / sp_num`` is the
            nominal segment size.
        min_size: minimum segment size passed to scikit-image; defaults to
            ``int(0.1 * segment_size)``.
        max_size: maximum segment size passed to scikit-image; defaults to
            ``int(1000.0 * segment_size)``.

    Returns:
        CPU tensor of shape ``(B, 1, h, w)`` with the relabelled maps.

    Raises:
        ImportError: if the scikit-image helper could not be imported.
    """
    if _enforce_label_connectivity_cython is None:
        raise ImportError(
            "enforce_connectivity requires "
            "skimage.segmentation._slic._enforce_label_connectivity_cython"
        )

    # The thresholds do not depend on the sample, so compute them once.
    segment_size = H * W / sp_num
    if min_size is None:
        min_size = int(0.1 * segment_size)
    if max_size is None:
        max_size = int(1000.0 * segment_size)

    rets = []
    for i in range(segs.shape[0]):
        seg = segs[i].squeeze().detach().cpu().numpy()
        # The helper works on (depth, H, W) volumes: add a depth axis of one
        # and strip it again from the result.
        seg = _enforce_label_connectivity_cython(seg[None], min_size, max_size)[0]
        rets.append(torch.from_numpy(seg).unsqueeze(0).unsqueeze(0))

    return torch.cat(rets)