sunshineatnoon
new_model
cde2253
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision.transforms.functional as TF
from .taming_blocks import Encoder
from .loss import styleLossMaskv3
from .nnutils import SPADEResnetBlock, get_edges, initWave
from libs.nnutils import poolfeat, upfeat
from libs.utils import label2one_hot_torch
from .meanshift_utils import meanshift_cluster, meanshift_assign
from swapae.models.networks.stylegan2_layers import ConvLayer
from torch_geometric.nn import GCNConv
from torch_geometric.utils import softmax
class GCN(nn.Module):
    """Two GCN layers over the superpixel graph, optionally followed by soft grouping.

    For each GCN layer the edge weights are the cosine similarity of the two
    endpoint features, normalised by ``compute_edge_score_softmax`` (a softmax
    over edges grouped by ``edge_index[1]``).
    """

    def __init__(self, n_cluster, temperature = 1, hidden_dim = 256):
        super().__init__()
        self.gcnconv1 = GCNConv(hidden_dim, hidden_dim, add_self_loops = True)
        self.gcnconv2 = GCNConv(hidden_dim, hidden_dim, add_self_loops = True)
        # Produces `n_cluster` grouping scores per pixel; only used when clustering.
        self.pool1 = nn.Sequential(nn.Conv2d(hidden_dim, n_cluster, 3, 1, 1))
        # Multiplies the grouping scores before the softmax in `forward`.
        self.temperature = temperature

    def compute_edge_score_softmax(self, raw_edge_score, edge_index, num_nodes):
        """Normalise edge scores with a softmax over edges that share ``edge_index[1]``."""
        return softmax(raw_edge_score, edge_index[1], num_nodes=num_nodes)

    def compute_edge_weight(self, node_feature, edge_index):
        """Return ``(raw, normalised)`` edge weights for one graph.

        Args:
            node_feature: 2-D (num_nodes, C) node feature matrix.
            edge_index: ``edge_index[0]`` / ``edge_index[1]`` are 1-D source /
                target node indices of the edges.
        """
        # Row indexing selects the same rows as gather() with a repeated
        # (num_edges, C) index, without materialising that index tensor.
        src_feat = node_feature[edge_index[0]]
        tgt_feat = node_feature[edge_index[1]]
        raw_edge_weight = F.cosine_similarity(src_feat, tgt_feat, dim=1, eps=1e-6)
        edge_weight = self.compute_edge_score_softmax(raw_edge_weight, edge_index, node_feature.shape[0])
        return raw_edge_weight.squeeze(), edge_weight.squeeze()

    def forward(self, sp_code, slic, clustering = False):
        """Propagate superpixel codes over the superpixel adjacency graph.

        Args:
            sp_code: batch of node features; ``sp_code[i]`` is the node matrix
                of image ``i`` and ``sp_code.shape[1]`` is passed to
                ``get_edges`` as the node count.
            slic: superpixel map whose argmax over dim 1 gives each pixel's
                superpixel index.
            clustering: if True, predict grouping scores with ``pool1`` and
                replace pixel features by their group-pooled features.

        Returns:
            ``(prop_code, sp_assign, conv_feats)``, each concatenated over the
            batch. ``prop_code`` holds the propagated (and, if clustering,
            group-pooled) features. ``sp_assign`` holds the pre-softmax grouping
            scores, or ``slic`` itself when not clustering. ``conv_feats`` is
            ``upfeat`` of the GCN output.
        """
        # The second value returned by `get_edges` is not needed here.
        edges, _ = get_edges(torch.argmax(slic, dim = 1).unsqueeze(1), sp_code.shape[1])
        prop_code = []
        sp_assign = []
        conv_feats = []
        for i in range(sp_code.shape[0]):
            slic_i = slic[i:i+1]
            edge_index = edges[i]
            # Layer 1: edge weights come from the input codes.
            _, edge_weight = self.compute_edge_weight(sp_code[i], edge_index)
            feat = self.gcnconv1(sp_code[i], edge_index, edge_weight = edge_weight)
            # Layer 2: edge weights are recomputed from layer-1 features (pre-activation).
            _, edge_weight = self.compute_edge_weight(feat, edge_index)
            feat = F.leaky_relu(feat, 0.2)
            feat = self.gcnconv2(feat, edge_index, edge_weight = edge_weight)
            conv_feat = upfeat(feat, slic_i)
            conv_feats.append(conv_feat)
            if clustering:
                pred_mask = self.pool1(conv_feat)
                # Enforce that pixels of one superpixel share the same grouping scores.
                pred_mask = upfeat(poolfeat(pred_mask, slic_i), slic_i)
                s_ = F.softmax(pred_mask * self.temperature, dim = 1)
                # Texture code of each group, pooled with the soft assignment.
                pool_feat = poolfeat(conv_feat, s_, avg = True)
                feat = upfeat(pool_feat, s_)
            else:
                feat = conv_feat
                pred_mask = slic_i
            prop_code.append(feat)
            sp_assign.append(pred_mask)
        return torch.cat(prop_code), torch.cat(sp_assign), torch.cat(conv_feats)
class SPADEGenerator(nn.Module):
    """SPADE-style decoder producing a 3-channel map from ``sine_wave``.

    Every residual block is conditioned on ``texon``. A 2x upsample follows
    ``head_0``, ``G_middle_1`` and ``up_0``, so three upsamples are applied in
    total. NOTE(review): whether the overall scale is exactly 8x depends on
    SPADEResnetBlock preserving resolution, which is defined elsewhere.
    """

    def __init__(self, in_dim, hidden_dim):
        super().__init__()
        nf = hidden_dim // 16
        # Channel schedule: in_dim -> 16nf -> 16nf -> 16nf -> 8nf -> 4nf -> 2nf -> nf.
        # Construction order is kept fixed so that parameter initialisation
        # draws from the RNG in the same sequence.
        self.head_0 = SPADEResnetBlock(in_dim, 16 * nf)
        self.G_middle_0 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.G_middle_1 = SPADEResnetBlock(16 * nf, 16 * nf)
        self.up_0 = SPADEResnetBlock(16 * nf, 8 * nf)
        self.up_1 = SPADEResnetBlock(8 * nf, 4 * nf)
        self.up_2 = SPADEResnetBlock(4 * nf, 2 * nf)
        self.up_3 = SPADEResnetBlock(2 * nf, nf)
        self.conv_img = nn.Conv2d(nf, 3, 3, padding=1)
        self.up = nn.Upsample(scale_factor=2)

    def forward(self, sine_wave, texon):
        """Run the block schedule on ``sine_wave`` and project to 3 channels."""
        # (block, upsample afterwards?) in execution order.
        schedule = (
            (self.head_0, True),
            (self.G_middle_0, False),
            (self.G_middle_1, True),
            (self.up_0, True),
            (self.up_1, False),
            (self.up_2, False),
            (self.up_3, False),
        )
        out = sine_wave
        for block, upsample in schedule:
            out = block(out, texon)
            if upsample:
                out = self.up(out)
        return self.conv_img(F.leaky_relu(out, 0.2))
class Waver(nn.Module):
    """Predict wave numbers as ``initWave(zPeriodic)`` plus a learned 1x1-conv term.

    The learned branch maps ``tex_code_dim`` channels to ``2 * zPeriodic``
    channels.
    """

    def __init__(self, tex_code_dim, zPeriodic):
        super().__init__()
        # Sequential indices (0: conv, 1: ReLU, 2: conv) define the state_dict keys.
        self.learnedWN = nn.Sequential(
            nn.Conv2d(tex_code_dim, tex_code_dim, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(tex_code_dim, 2 * zPeriodic, 1),
        )
        self.waveNumbers = initWave(zPeriodic)

    def forward(self, GLZ=None):
        """Return ``waveNumbers + learnedWN(GLZ)``; ``GLZ`` must be a tensor despite the default."""
        base = self.waveNumbers.to(GLZ.device)
        return base + self.learnedWN(GLZ)
class AE(nn.Module):
    """Superpixel texture autoencoder with GCN grouping and texture synthesis.

    ``forward`` runs the following steps:
    1. Encode the image.
    2. Pool a texture code per superpixel.
    3. From ``add_gcn_epoch`` on, propagate (and from ``add_clustering_epoch``
       on, group) the codes with ``GCN``.
    4. From ``add_texture_epoch`` on, cluster superpixels with meanshift,
       synthesise a texture for one group and compute a masked style loss
       against the input image.
    """

    def __init__(self, args, **ignore_kwargs):
        super().__init__()
        # Channels of random noise fed to the generator next to the sine waves.
        # Set before `self.G` because the generator's input width depends on it.
        self.noise_dim = 32
        # encoder & decoder
        self.enc = Encoder(ch=64, out_ch=3, ch_mult=[1, 2, 4, 8], num_res_blocks=1, attn_resolutions=[],
                           in_channels=3, resolution=args.crop_size, z_channels=args.hidden_dim, double_z=False)
        self.G = SPADEGenerator(args.spatial_code_dim + self.noise_dim, args.hidden_dim)
        self.add_module(
            "ToTexCode",
            nn.Sequential(
                ConvLayer(args.hidden_dim, args.hidden_dim, kernel_size=3, activate=True, bias=True),
                ConvLayer(args.hidden_dim, args.tex_code_dim, kernel_size=3, activate=True, bias=True),
                ConvLayer(args.tex_code_dim, args.hidden_dim, kernel_size=1, activate=False, bias=False),
            ),
        )
        self.gcn = GCN(n_cluster=args.n_cluster, temperature=args.temperature, hidden_dim=args.hidden_dim)
        # Epoch thresholds that switch on the GCN, its clustering head and the
        # texture-synthesis branch (see `forward`).
        self.add_gcn_epoch = args.add_gcn_epoch
        self.add_clustering_epoch = args.add_clustering_epoch
        self.add_texture_epoch = args.add_texture_epoch
        self.patch_size = args.patch_size
        self.style_loss = styleLossMaskv3(device=args.device)
        self.sine_wave_dim = args.spatial_code_dim
        self.spatial_code_dim = args.spatial_code_dim
        # Periodic (sine-wave) spatial input, only built when spatial codes are used.
        if args.spatial_code_dim > 0:
            self.learnedWN = Waver(args.hidden_dim, zPeriodic=args.spatial_code_dim)
            self.add_module(
                "Amplitude",
                nn.Sequential(
                    nn.Conv2d(args.hidden_dim, args.hidden_dim // 2, 1, 1, 0),
                    nn.Conv2d(args.hidden_dim // 2, args.hidden_dim // 4, 1, 1, 0),
                    nn.Conv2d(args.hidden_dim // 4, args.spatial_code_dim, 1, 1, 0),
                ),
            )
        # Bandwidth passed to the meanshift clustering in `forward`.
        self.bandwidth = 3.0

    def sample_patch_from_mask(self, mask, patch_num = 10, patch_size = 64):
        """Sample ``patch_num`` patch boxes of at most ``patch_size`` x ``patch_size`` from ``mask``.

        Each box is centred on a uniformly chosen non-zero pixel of ``mask`` and
        clipped to the image, so boxes near the border are smaller. A box is
        accepted when the mask covers more than half of a full
        ``patch_size**2`` patch. Once more than 20 draws have been made, every
        draw is accepted, so the loop always terminates.

        Args:
            mask: 2-D (H, W) mask of the region to sample from.
            patch_num: number of boxes to return.
            patch_size: nominal side length of a patch.

        Returns:
            List of ``[top, bot, left, right]`` Python ints (half-open ranges).

        Raises:
            ValueError: if ``patch_num > 0`` and ``mask`` has no non-zero pixel.
        """
        imgH, imgW = mask.shape
        # Flat indices of the mask pixels. `flatten()` keeps the result 1-D even
        # for a single non-zero pixel; `.squeeze()` would give a 0-d tensor and
        # `len()` would fail.
        nonzeros = torch.nonzero(mask.reshape(-1)).flatten()
        n = len(nonzeros)
        if n == 0 and patch_num > 0:
            raise ValueError("sample_patch_from_mask: mask has no non-zero pixel to sample from")
        half_patch = patch_size // 2
        xys = []
        iter_num = 0
        while len(xys) < patch_num:
            # Uniform choice among the non-zero pixels.
            pick = int(torch.randint(n, (1,)))
            rx, ry = divmod(int(nonzeros[pick]), imgW)
            top = max(0, rx - half_patch)
            bot = min(imgH, rx + half_patch)
            left = max(0, ry - half_patch)
            right = min(imgW, ry + half_patch)
            patch_mask = mask[top:bot, left:right]
            if torch.sum(patch_mask) / (patch_size ** 2) > 0.5 or iter_num > 20:
                xys.append([top, bot, left, right])
            iter_num += 1
        return xys

    def get_sine_wave(self, GL, offset_mode = 'rec'):
        """Build the periodic spatial input of the generator from the code map ``GL``.

        ``GL`` is downsampled 8x (nearest). ``learnedWN`` predicts per-pixel
        frequencies, which multiply the (row, col) pixel coordinates. Channel
        pairs are summed into a phase and passed through ``sin``. The output
        stacks two variants along the batch dimension: first zero phase offset,
        then a random per-(sample, channel) offset in [-6.28, 6.28). Both are
        scaled by ``sigmoid(Amplitude(GL))``.

        ``offset_mode`` is accepted for API compatibility and is unused.
        """
        imgH, imgW = GL.shape[-2] // 8, GL.shape[-1] // 8
        GL = F.interpolate(GL, size = (imgH, imgW), mode = 'nearest')
        batch = GL.shape[0]
        # Pixel coordinates: channel 0 = row index, channel 1 = column index,
        # tiled to 2 * sine_wave_dim channels to match the predicted frequencies.
        rows = torch.arange(imgH, dtype=torch.float32, device=GL.device).view(imgH, 1).expand(imgH, imgW)
        cols = torch.arange(imgW, dtype=torch.float32, device=GL.device).view(1, imgW).expand(imgH, imgW)
        c = torch.stack((rows, cols)).unsqueeze(0).repeat(batch, self.sine_wave_dim, 1, 1)
        raw = self.learnedWN(GL) * c
        # Phase of each sine channel: row term + column term of one channel pair.
        phase = raw[:, ::2] + raw[:, 1::2]
        # Random offset, one value per (sample, channel), broadcast over pixels.
        roffset = torch.zeros((batch, self.sine_wave_dim, 1, 1)).to(GL.device).uniform_(-1, 1) * 6.28
        rwave = torch.sin(phase + roffset)
        zwave = torch.sin(phase)
        A = torch.sigmoid(self.Amplitude(GL))
        return torch.cat((zwave, rwave)) * A.repeat(2, 1, 1, 1)

    def syn_tex(self, tex_code, mask, imgH, imgW, offset_mode = 'rec', tex_idx = None):
        """Synthesise an ``imgH`` x ``imgW`` texture for one group of ``mask``.

        Args:
            tex_code: (B, N, C) texture code per group. The code path assumes
                B == 1 (``areas[0]`` and ``view(1, -1, 1, 1)``).
            mask: group masks with groups on dim 1, summed over dims (2, 3)
                to get group areas.
            imgH, imgW: size of the texture to synthesise.
            offset_mode: forwarded to ``get_sine_wave`` (unused there).
            tex_idx: ``None`` or out of range selects a random group, sampled
                with probability proportional to area among groups that cover
                more than 1% of the mask. Otherwise it is the rank of the group
                by decreasing area.

        Returns:
            ``(tex_syn, tex_idx)``: the generator output and the chosen group index.
        """
        H = imgH // 8
        W = imgW // 8
        areas = torch.sum(mask, dim=(2, 3))
        if tex_idx is None or tex_idx >= tex_code.shape[1]:
            # Throw away small texture segments. The 1% threshold is relative to
            # the mask's own resolution, which differs from imgH x imgW when
            # synthesising at a custom (test) size.
            mask_pixels = mask.shape[-2] * mask.shape[-1]
            valid_idxs = torch.nonzero(areas[0] / mask_pixels > 0.01).squeeze(-1)
            if valid_idxs.numel() == 0:
                # Every group is below the threshold: fall back to all non-empty groups.
                valid_idxs = torch.nonzero(areas[0] > 0).squeeze(-1)
            tex_idx = valid_idxs[torch.multinomial(areas[0, valid_idxs], 1).squeeze()]
        else:
            sorted_list = torch.argsort(areas, dim = 1, descending = True)
            tex_idx = sorted_list[0, tex_idx]
        sampled_code = tex_code[:, tex_idx, :]
        rec_tex = sampled_code.view(1, -1, 1, 1).repeat(1, 1, imgH, imgW)
        # Decoder: Spatial & Texture code -> Image
        if self.noise_dim == 0:
            dec_input = self.get_sine_wave(rec_tex, offset_mode)
        elif self.spatial_code_dim == 0:
            dec_input = torch.randn(rec_tex.shape[0], self.noise_dim, H, W).to(tex_code.device)
        else:
            sine_wave = self.get_sine_wave(rec_tex, offset_mode)
            noise = torch.randn(sine_wave.shape[0], self.noise_dim, H, W).to(tex_code.device)
            dec_input = torch.cat((sine_wave, noise), dim = 1)
        tex_syn = self.G(dec_input, rec_tex.repeat(dec_input.shape[0], 1, 1, 1))
        return tex_syn, tex_idx

    def sample_tex_patches(self, tex_idx, rgb_img, rep_rec, mask, patch_num = 10):
        """Sample patches of group ``tex_idx`` from ``rgb_img`` and random crops of ``rep_rec``.

        With ``P = self.patch_size``, the return values are:
            masks: ``mask[i, tex_idx]`` per image, stacked to (B, 1, H, W).
            patch_masks: (B * patch_num, 1, P, P) mask patches.
            patches: (B * patch_num, 3, P, P) image patches.
            rep_patches: ``patch_num`` random P x P crops of ``rep_rec``,
                reshaped to (-1, 3, P, P).
        Border patches smaller than P are pasted at a random position on a
        zero canvas.
        """
        P = self.patch_size
        patches = []
        masks = []
        patch_masks = []
        # sample patches from input image
        for i in range(rgb_img.shape[0]):
            # WARNING: This only works for batch_size = 1 for now
            maski = mask[i, tex_idx]
            masks.append(maski.unsqueeze(0))
            xys = self.sample_patch_from_mask(maski, patch_num = patch_num, patch_size = P)
            for top, bot, left, right in xys:
                patch_ = rgb_img[i, :, top:bot, left:right]
                patch_mask_ = maski[top:bot, left:right]
                h, w = patch_.shape[-2:]
                # Random paste position. `+ 1` because randint's upper bound is
                # exclusive; without it the patch could never touch the
                # bottom/right edge.
                x = np.random.randint(0, P - h + 1) if h < P else 0
                y = np.random.randint(0, P - w + 1) if w < P else 0
                patch = torch.zeros(1, 3, P, P, device=patch_.device)
                patch_mask = torch.zeros(1, 1, P, P, device=patch_.device)
                patch[:, :, x:x+h, y:y+w] = patch_
                patch_mask[:, :, x:x+h, y:y+w] = patch_mask_
                patches.append(patch)
                patch_masks.append(patch_mask)
        patches = torch.cat(patches)
        masks = torch.stack(masks)
        patch_masks = torch.cat(patch_masks)
        # sample patches from synthesized texture
        rep_patches = []
        for _ in range(patch_num):
            i, j, h, w = transforms.RandomCrop.get_params(rep_rec, output_size=(P, P))
            rep_patches.append(TF.crop(rep_rec, i, j, h, w))
        rep_patches = torch.stack(rep_patches, dim = 1).view(-1, 3, P, P)
        return masks, patch_masks, patches, rep_patches

    def forward(self, rgb_img, slic, epoch = 0, test_time = False, test = False, tex_idx = None,
                test_syn_size = 564):
        """Encode, group and (from ``add_texture_epoch`` on) synthesise textures.

        Args:
            rgb_img: (B, 3, H, W) input image.
            slic: superpixel map (argmax over dim 1 gives the superpixel index).
            epoch: current epoch, compared with the ``add_*_epoch`` thresholds.
            test_time: unused, kept for API compatibility.
            test: if True, synthesise a ``test_syn_size`` x ``test_syn_size``
                texture instead of one of the input size.
            tex_idx: group selector forwarded to ``syn_tex``.
            test_syn_size: synthesis size used when ``test`` is True.

        Returns:
            Dict with key ``'HA'`` (segmentation map). Once ``add_texture_epoch``
            is reached, it also has ``'style_loss'``, ``'rep_rec'``, ``'masks'``,
            ``'patches'``, ``'patch_masks'``, ``'rep_patches'``, ``'gt'`` and
            ``'rec'``.

        Raises:
            RuntimeError: if the texture branch is active before the GCN branch,
                whose features it needs.
        """
        _, _, imgH, imgW = rgb_img.shape
        outputs = {}
        seg_map = [torch.argmax(slic.cpu(), dim = 1)]
        # Encoder: img (B, 3, H, W) -> feature (B, C, imgH//8, imgW//8)
        conv_feat, _ = self.enc(rgb_img)
        # Texture code for each superpixel
        tex_code = self.ToTexCode(conv_feat)
        code = F.interpolate(tex_code, size = (imgH, imgW), mode = 'bilinear', align_corners = False)
        pool_code = poolfeat(code, slic, avg = True)
        conv_feats = None
        if epoch >= self.add_gcn_epoch:
            _, sp_assign, conv_feats = self.gcn(pool_code, slic, (self.add_clustering_epoch <= epoch))
            seg_map = [torch.argmax(sp_assign.cpu(), dim = 1)]
        # Texture synthesis
        if epoch >= self.add_texture_epoch:
            if conv_feats is None:
                raise RuntimeError(
                    f"AE.forward: texture synthesis (add_texture_epoch={self.add_texture_epoch}) "
                    f"needs GCN features, which are only computed from add_gcn_epoch={self.add_gcn_epoch}"
                )
            sp_feat = poolfeat(conv_feats, slic, avg = True).squeeze(0)
            pts = meanshift_cluster(sp_feat, self.bandwidth, meanshift_step = 15)[-1]
            with torch.no_grad():
                sp_assign, _ = meanshift_assign(pts, self.bandwidth)
            sp_assign = torch.tensor(sp_assign).unsqueeze(-1).to(slic.device).float()
            sp_assign = upfeat(sp_assign, slic)
            seg = label2one_hot_torch(sp_assign, C = sp_assign.max().long() + 1)
            seg_map = [torch.argmax(seg.cpu(), dim = 1)]
            # texture code for each connected group
            tex_seg = poolfeat(conv_feats, seg, avg = True)
            syn_h, syn_w = (test_syn_size, test_syn_size) if test else (imgH, imgW)
            rep_rec, tex_idx = self.syn_tex(tex_seg, seg, syn_h, syn_w, tex_idx = tex_idx)
            # (x + 1) / 2 maps [-1, 1] to [0, 1]; assumes both tensors use that range.
            rep_rec = (rep_rec + 1) / 2.0
            rgb_img = (rgb_img + 1) / 2.0
            # Row 0 of rep_rec is the zero-offset synthesis and rows 1: the
            # random-offset one (when sine waves are used, see get_sine_wave).
            zmasks, zpatch_masks, zpatches, zrep_patches = self.sample_tex_patches(tex_idx, rgb_img, rep_rec[:1], seg)
            rmasks, rpatch_masks, rpatches, rrep_patches = self.sample_tex_patches(tex_idx, rgb_img, rep_rec[1:], seg)
            masks = torch.cat((zmasks, rmasks))
            patch_masks = torch.cat((zpatch_masks, rpatch_masks))
            patches = torch.cat((zpatches, rpatches))
            rep_patches = torch.cat((zrep_patches, rrep_patches))
            # Masked Gram-matrix style loss: synthesised-texture patches vs. the input image.
            outputs['style_loss'] = self.style_loss.forward_patch_img(rep_patches, rgb_img.repeat(2, 1, 1, 1), masks)
            outputs['rep_rec'] = rep_rec
            outputs['masks'] = masks
            outputs['patches'] = patches.view(-1, 3, self.patch_size, self.patch_size)
            outputs['patch_masks'] = patch_masks
            outputs['rep_patches'] = rep_patches * patch_masks + patches * (1 - patch_masks)
            outputs['gt'] = rgb_img
            # Paste the zero-offset texture into the selected group of the input image.
            bp_tex = rep_rec[:1, :, :imgH, :imgW] * masks[:1] + rgb_img * (1 - masks[:1])
            outputs['rec'] = bp_tex
        outputs['HA'] = torch.cat(seg_map)
        return outputs