import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.segmentation._slic import _enforce_label_connectivity_cython
def initWave(nPeriodic):
    """Build the initial wave-vector tensor for ``nPeriodic`` periodic channels.

    Emits pairs of (x, y) wave components with magnitudes stepping up from
    0.5, covering two quadrants per magnitude, truncated to exactly
    ``2 * nPeriodic`` entries and scaled by pi.

    Returns:
        torch.FloatTensor of shape (1, 2 * nPeriodic, 1, 1).
    """
    entries = []
    # tiny epsilon keeps the denominator nonzero when nPeriodic < 4
    denom = float(nPeriodic // 4 + 1e-10)
    for step in range(nPeriodic // 4 + 1):
        mag = 0.5 + step / denom
        entries.extend([0, mag, mag, 0])
        entries.extend([0, -mag, mag, 0])  # cover the other quadrants as well
    wave = np.asarray(entries[:2 * nPeriodic], dtype=np.float32) * np.pi
    return torch.from_numpy(wave).reshape(1, -1, 1, 1)
class SPADEGenerator(nn.Module):
    """Decoder that upsamples a latent feature map 16x while modulating every
    stage with a SPADE residual block conditioned on the segmentation input.

    Channel widths taper from ``16 * nf`` down to ``nf`` (``nf = hidden_dim // 16``)
    before a final 3x3 conv projects to a 3-channel image.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        nf = hidden_dim // 16
        # deepest stage: three blocks at full width
        self.head_0 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf)
        # upsampling path: halve/quarter the channel count per stage
        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(4 * nf, nf)
        # self.up_3 = SPADEResnetBlock(2 * nf, 1 * nf)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x, input):
        # ``input`` is the conditioning map fed to every SPADE block
        seg = input
        x = self.up(self.head_0(x, seg))
        x = self.G_middle_1(self.G_middle_0(x, seg), seg)
        x = self.up_0(self.up(x), seg)
        x = self.up_1(self.up(x), seg)
        x = self.up_2(self.up(x), seg)
        # x = self.up_3(self.up(x), seg)
        return self.conv_img(F.leaky_relu(x, 2e-1))
class SPADE(nn.Module):
    """Spatially-adaptive normalization (SPADE).

    Normalizes activations with a parameter-free InstanceNorm, then modulates
    them element-wise with gamma/beta maps predicted from a conditioning map
    (``segmap``) by a small shared conv + two conv heads.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        kernel = 3
        pad = kernel // 2
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        # The dimension of the intermediate embedding space. Yes, hardcoded.
        hidden = 128
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, hidden, kernel_size=kernel, padding=pad),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(hidden, norm_nc, kernel_size=kernel, padding=pad)
        self.mlp_beta = nn.Conv2d(hidden, norm_nc, kernel_size=kernel, padding=pad)

    def forward(self, x, segmap):
        # Part 1: parameter-free normalization of the activations.
        normalized = self.param_free_norm(x)
        # Part 2: resize the conditioning map to the activation's spatial size
        # (bilinear was chosen over the commented-out nearest mode upstream),
        # then predict per-pixel scale and bias from it.
        cond = F.interpolate(segmap, size=x.size()[2:], mode='bilinear',
                             align_corners=False)
        shared = self.mlp_shared(cond)
        gamma = self.mlp_gamma(shared)
        beta = self.mlp_beta(shared)
        # Apply the predicted modulation.
        return normalized * (1 + gamma) + beta
class SPADEResnetBlock(nn.Module):
    """Residual block whose normalization layers are SPADE modules, so both
    the main path and (when needed) the shortcut are conditioned on ``seg``.

    A learned 1x1 shortcut projection is created only when the input and
    output channel counts differ.
    """

    def __init__(self, fin, fout):
        super().__init__()
        self.learned_shortcut = fin != fout
        fmiddle = min(fin, fout)
        # main-path convolutions
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        # SPADE normalizations; the conditioning map is assumed to carry
        # 256 channels here — TODO confirm against the callers
        self.norm_0 = SPADE(fin, 256)
        self.norm_1 = SPADE(fmiddle, 256)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, 256)

    def forward(self, x, seg):
        # the block takes |seg|, the semantic map, alongside the activations
        residual = self.conv_0(self.actvn(self.norm_0(x, seg)))
        residual = self.conv_1(self.actvn(self.norm_1(residual, seg)))
        return self.shortcut(x, seg) + residual

    def shortcut(self, x, seg):
        # identity unless channel counts differ, then a conditioned 1x1 conv
        if not self.learned_shortcut:
            return x
        return self.conv_s(self.norm_s(x, seg))

    def actvn(self, x):
        return F.leaky_relu(x, 2e-1)
def get_edges(sp_label, sp_num):
    """Build the superpixel adjacency graph from a dense label map.

    Two superpixels are neighbors when their pixels touch horizontally,
    vertically, or diagonally (8-connectivity).

    Args:
        sp_label: (B, 1, H, W) tensor of superpixel ids in [0, sp_num).
        sp_num: number of superpixels per image.

    Returns:
        edge_indices: list of B long tensors of shape (2, E), each listing
            every neighbor pair in both directions (no self-loops).
        n_affs: (B, sp_num, sp_num) tensor; entry (i, j) is 1 when
            superpixels i and j are neighbors, otherwise 0.

    Fixes over the previous version: removed a bare ``except:`` wrapping
    ``torch.nonzero`` that dropped into ``pdb`` (debugging leftover), and
    corrected the documented matrix size (it is sp_num x sp_num, not hw x hw).
    """
    # Label differences along the four directions defining 8-connectivity;
    # a nonzero difference marks a boundary between two superpixels.
    top = sp_label[:, :, :-1, :] - sp_label[:, :, 1:, :]
    left = sp_label[:, :, :, :-1] - sp_label[:, :, :, 1:]
    top_left = sp_label[:, :, :-1, :-1] - sp_label[:, :, 1:, 1:]
    top_right = sp_label[:, :, :-1, 1:] - sp_label[:, :, 1:, :-1]
    # For each diff map, the (row, col) offsets of the two pixels of a pair.
    directions = (
        (top, (0, 0), (1, 0)),        # top/bottom
        (left, (0, 0), (0, 1)),       # left/right
        (top_left, (0, 0), (1, 1)),   # top-left diagonal
        (top_right, (0, 1), (1, 0)),  # top-right diagonal
    )
    n_affs = []
    edge_indices = []
    for i in range(sp_label.shape[0]):
        # change to torch.ones below to include self-loop in graph
        n_aff = torch.zeros(sp_num, sp_num).unsqueeze(0).to(sp_label.device)
        for diff, (r1, c1), (r2, c2) in directions:
            r, c = torch.nonzero(diff[i].squeeze(), as_tuple=True)
            sp1 = sp_label[i, :, r + r1, c + c1].squeeze().long()
            sp2 = sp_label[i, :, r + r2, c + c2].squeeze().long()
            # mark both directions: the graph is undirected
            n_aff[:, sp1, sp2] = 1
            n_aff[:, sp2, sp1] = 1
        n_affs.append(n_aff)
        edge_index = torch.stack(torch.nonzero(n_aff.squeeze(), as_tuple=True))
        edge_indices.append(edge_index.to(sp_label.device))
    return edge_indices, torch.cat(n_affs)
def enforce_connectivity(segs, H, W, sp_num = 196, min_size = None, max_size = None):
    """Enforce connected superpixel labels on a batch of segmentation maps.

    Each (H, W) map in ``segs`` is run through skimage's internal label
    connectivity routine; fragments smaller than ``min_size`` are merged and
    regions larger than ``max_size`` are split. When the size bounds are not
    given they default to 0.1x / 1000x the average segment area.

    Args:
        segs: batch tensor; each slice squeezes to an (H, W) label map.
    Returns:
        (B, 1, H, W) tensor of relabeled maps.
    """
    # size bounds depend only on H, W, sp_num, so compute them once
    segment_size = H * W / sp_num
    if min_size is None:
        min_size = int(0.1 * segment_size)
    if max_size is None:
        max_size = int(1000.0 * segment_size)
    relabeled = []
    for idx in range(segs.shape[0]):
        labels = segs[idx].squeeze().cpu().numpy()
        labels = _enforce_label_connectivity_cython(labels[None], min_size, max_size)[0]
        relabeled.append(torch.from_numpy(labels).unsqueeze(0).unsqueeze(0))
    return torch.cat(relabeled)