sunshineatnoon
Add application file
1b2a9b1
raw
history blame
6.78 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from skimage.segmentation._slic import _enforce_label_connectivity_cython
def initWave(nPeriodic):
    """Build the fixed wave-coefficient tensor for ``nPeriodic`` periodic channels.

    With ``q = nPeriodic // 4``, every step ``k`` in ``0 .. q`` yields a magnitude
    ``v_k = 0.5 + k / q`` (a tiny epsilon in the denominator keeps ``0 / 0`` from
    occurring when ``q == 0``), emitted as the flat pattern
    ``[0, v_k, v_k, 0, 0, -v_k, v_k, 0]``. The concatenated values are cut to
    the first ``2 * nPeriodic`` entries and multiplied by pi.

    NOTE(review): presumably consumed as ``nPeriodic`` 2-D wave vectors
    (consecutive value pairs) -- confirm with the caller.

    Returns:
        float32 tensor of shape ``(1, 2 * nPeriodic, 1, 1)``.
    """
    quarter = nPeriodic // 4
    denom = float(quarter + 1e-10)
    flat = []
    for k in range(quarter + 1):
        mag = 0.5 + k / denom
        # The second half mirrors the first (-v, v) so the other quadrant is covered too.
        flat.extend((0, mag, mag, 0, 0, -mag, mag, 0))
    flat = flat[:2 * nPeriodic]
    scaled = np.asarray(flat, dtype=np.float32) * np.pi
    return torch.FloatTensor(scaled)[None, :, None, None]
class SPADEGenerator(nn.Module):
    """Decode a feature map into a 3-channel image, conditioning every
    residual stage on a guidance map through SPADE.

    With ``nf = hidden_dim // 16`` the channel widths are
    16nf -> 16nf -> 16nf -> 16nf -> 8nf -> 4nf -> nf -> 3, and four 2x
    nearest-neighbour upsamplings (16x overall) are interleaved between stages.

    Args:
        hidden_dim: width of the incoming feature map; presumably a multiple of
            16 so that ``16 * nf == hidden_dim`` -- confirm with callers.
    """

    def __init__(self, hidden_dim):
        super().__init__()
        nf = hidden_dim // 16
        wide = 16 * nf
        # Registration order is unchanged so parameters()/state_dict() ordering matches.
        self.head_0 = SPADEResnetBlock(wide, wide)
        self.G_middle_0 = SPADEResnetBlock(wide, wide)
        self.G_middle_1 = SPADEResnetBlock(wide, wide)
        self.up_0 = SPADEResnetBlock(wide, 8 * nf)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(4 * nf, nf)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, x, input):
        """Decode ``x`` using ``input`` as the SPADE conditioning map.

        ``input`` is resized to the current resolution inside each SPADE layer,
        so its spatial size need not match ``x``; its channel count must match
        the label width SPADEResnetBlock gives its SPADE layers (256 by default).
        """
        seg = input
        # (stage, upsample 2x after the stage?)
        schedule = (
            (self.head_0, True),
            (self.G_middle_0, False),
            (self.G_middle_1, True),
            (self.up_0, True),
            (self.up_1, True),
            (self.up_2, False),
        )
        for stage, double_after in schedule:
            x = stage(x, seg)
            if double_after:
                x = self.up(x)
        return self.conv_img(F.leaky_relu(x, 2e-1))
class SPADE(nn.Module):
    """Spatially-adaptive normalization.

    ``x`` is instance-normalized without learned affine parameters, then
    modulated per pixel as ``norm(x) * (1 + gamma) + beta``, where ``gamma``
    and ``beta`` are predicted from ``segmap`` by small convolutional heads.

    Args:
        norm_nc: channel count of the features ``x`` being normalized.
        label_nc: channel count of the conditioning map ``segmap``.
    """

    def __init__(self, norm_nc, label_nc):
        super().__init__()
        kernel = 3
        pad = kernel // 2
        embed_dim = 128  # width of the shared embedding; fixed, not configurable
        self.param_free_norm = nn.InstanceNorm2d(norm_nc, affine=False)
        self.mlp_shared = nn.Sequential(
            nn.Conv2d(label_nc, embed_dim, kernel_size=kernel, padding=pad),
            nn.ReLU(),
        )
        self.mlp_gamma = nn.Conv2d(embed_dim, norm_nc, kernel_size=kernel, padding=pad)
        self.mlp_beta = nn.Conv2d(embed_dim, norm_nc, kernel_size=kernel, padding=pad)

    def forward(self, x, segmap):
        """Normalize ``x`` and modulate it with parameters predicted from ``segmap``.

        ``segmap`` is bilinearly resized to the spatial size of ``x`` first.
        """
        normalized = self.param_free_norm(x)
        resized = F.interpolate(segmap, size=x.shape[2:], mode='bilinear', align_corners=False)
        features = self.mlp_shared(resized)
        scale = self.mlp_gamma(features)
        shift = self.mlp_beta(features)
        return normalized * (1 + scale) + shift
class SPADEResnetBlock(nn.Module):
    """Residual block whose two 3x3 convolutions are each preceded by SPADE
    normalization and a LeakyReLU(0.2).

    The shortcut is the identity when ``fin == fout``; otherwise it is a
    SPADE-normalized, bias-free 1x1 convolution.

    Args:
        fin: input channels.
        fout: output channels.
        label_nc: channels of the conditioning map ``seg`` passed to
            ``forward``. Defaults to 256, the value this block previously
            hard-coded.
    """

    def __init__(self, fin, fout, label_nc=256):
        super().__init__()
        self.learned_shortcut = (fin != fout)
        fmiddle = min(fin, fout)
        # Convolutions first, then norms: this is the original registration
        # order, so parameters()/state_dict() ordering is unchanged.
        self.conv_0 = nn.Conv2d(fin, fmiddle, kernel_size=3, padding=1)
        self.conv_1 = nn.Conv2d(fmiddle, fout, kernel_size=3, padding=1)
        if self.learned_shortcut:
            self.conv_s = nn.Conv2d(fin, fout, kernel_size=1, bias=False)
        self.norm_0 = SPADE(fin, label_nc)
        self.norm_1 = SPADE(fmiddle, label_nc)
        if self.learned_shortcut:
            self.norm_s = SPADE(fin, label_nc)

    def forward(self, x, seg):
        """Return ``shortcut(x) + residual(x)``; every SPADE layer is conditioned on ``seg``."""
        x_s = self.shortcut(x, seg)
        dx = self.conv_0(self.actvn(self.norm_0(x, seg)))
        dx = self.conv_1(self.actvn(self.norm_1(dx, seg)))
        return x_s + dx

    def shortcut(self, x, seg):
        """Identity path, or a learned 1x1 projection when the widths differ."""
        if self.learned_shortcut:
            return self.conv_s(self.norm_s(x, seg))
        return x

    def actvn(self, x):
        """LeakyReLU with negative slope 0.2."""
        return F.leaky_relu(x, 2e-1)
def get_edges(sp_label, sp_num, *, self_loops=False):
    """Build the superpixel adjacency graph of each label map in a batch.

    Superpixels ``i`` and ``j`` are neighbours when a pixel labelled ``i``
    touches a pixel labelled ``j`` horizontally, vertically or diagonally
    (8-connectivity).

    Args:
        sp_label: tensor of shape ``(B, 1, H, W)`` holding superpixel ids in
            ``[0, sp_num)`` (any dtype convertible with ``.long()``).
        sp_num: number of superpixels, i.e. the side of the adjacency matrix.
        self_loops: if True, also set the diagonal (each superpixel is its own
            neighbour). Defaults to False.

    Returns:
        ``(edge_indices, adjacency)``. ``adjacency`` is a float tensor of
        shape ``(B, sp_num, sp_num)`` on ``sp_label``'s device; it is
        symmetric, and ``adjacency[b, i, j] == 1`` iff ``i`` and ``j`` are
        neighbours in sample ``b``. ``edge_indices`` is a list of ``B`` long
        tensors of shape ``(2, E_b)`` holding the (row, col) positions of each
        sample's nonzero entries, in row-major order.

    Raises:
        ValueError: if ``sp_label`` is not 4-D with a single channel.
    """
    if sp_label.dim() != 4 or sp_label.shape[1] != 1:
        raise ValueError(
            f"sp_label must have shape (B, 1, H, W), got {tuple(sp_label.shape)}")
    batch = sp_label.shape[0]
    adjacency = torch.zeros(batch, sp_num, sp_num, device=sp_label.device)
    edge_indices = []
    for idx in range(batch):
        # Index the channel explicitly instead of squeeze(), which would also
        # drop spatial dims of size 1 (e.g. H == 2 after shifting).
        labels = sp_label[idx, 0]
        # (pixel, neighbour) views for the four undirected directions; the
        # other four of the eight neighbours follow by symmetry.
        shifted_pairs = (
            (labels[:-1, :], labels[1:, :]),      # vertical
            (labels[:, :-1], labels[:, 1:]),      # horizontal
            (labels[:-1, :-1], labels[1:, 1:]),   # main diagonal
            (labels[:-1, 1:], labels[1:, :-1]),   # anti-diagonal
        )
        src_parts, dst_parts = [], []
        for first, second in shifted_pairs:
            boundary = first != second
            src_parts.append(first[boundary])
            dst_parts.append(second[boundary])
        src = torch.cat(src_parts).long()
        dst = torch.cat(dst_parts).long()
        adjacency[idx, src, dst] = 1
        adjacency[idx, dst, src] = 1
        if self_loops:
            adjacency[idx].fill_diagonal_(1)
        edge_indices.append(torch.stack(torch.nonzero(adjacency[idx], as_tuple=True)))
    return edge_indices, adjacency
def enforce_connectivity(segs, H, W, sp_num=196, min_size=None, max_size=None):
    """Post-process a batch of superpixel label maps for spatial connectivity.

    Each map is passed to scikit-image's private SLIC helper
    ``_enforce_label_connectivity_cython``. NOTE(review): presumably (as in
    ``skimage.segmentation.slic``) it relabels disconnected fragments and
    merges segments smaller than ``min_size`` into a neighbour. This is a
    private API, so verify against the installed scikit-image version.

    Args:
        segs: batch of label maps; each ``segs[i]`` must squeeze to ``(h, w)``
            (e.g. ``(B, 1, h, w)`` or ``(B, h, w)``).
        H, W: image size, used only to derive the default size thresholds.
        sp_num: expected superpixel count; the average segment area is
            ``H * W / sp_num``.
        min_size: minimum segment size; defaults to 10% of the average area.
        max_size: size bound forwarded to the helper; defaults to 1000x the
            average area.

    Returns:
        CPU tensor of shape ``(B, 1, h, w)`` with the processed maps.
    """
    # Thresholds are loop-invariant: compute them once, not per sample.
    segment_size = H * W / sp_num
    if min_size is None:
        min_size = int(0.1 * segment_size)
    if max_size is None:
        max_size = int(1000.0 * segment_size)
    rets = []
    for seg in segs:
        labels = seg.detach().squeeze().cpu().numpy()
        # The Cython helper takes a C-contiguous Py_ssize_t (np.intp) 3-D buffer;
        # other integer dtypes or strided views would be rejected with a buffer error.
        labels = np.ascontiguousarray(labels, dtype=np.intp)
        labels = _enforce_label_connectivity_cython(labels[None], min_size, max_size)[0]
        rets.append(torch.from_numpy(labels)[None, None])
    return torch.cat(rets)