|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from models.modelUtils import ChebConv, Pool, residualBlock |
|
import torchvision.ops.roi_align as roi_align |
|
|
|
import numpy as np |
|
|
|
class EncoderConv(nn.Module): |
|
def __init__(self, latents = 64, hw = 32): |
|
super(EncoderConv, self).__init__() |
|
|
|
self.latents = latents |
|
self.c = 4 |
|
|
|
self.size = self.c * np.array([2,4,8,16,32], dtype = np.intc) |
|
|
|
self.maxpool = nn.MaxPool2d(2) |
|
|
|
self.dconv_down1 = residualBlock(1, self.size[0]) |
|
self.dconv_down2 = residualBlock(self.size[0], self.size[1]) |
|
self.dconv_down3 = residualBlock(self.size[1], self.size[2]) |
|
self.dconv_down4 = residualBlock(self.size[2], self.size[3]) |
|
self.dconv_down5 = residualBlock(self.size[3], self.size[4]) |
|
self.dconv_down6 = residualBlock(self.size[4], self.size[4]) |
|
|
|
self.fc_mu = nn.Linear(in_features=self.size[4]*hw*hw, out_features=self.latents) |
|
self.fc_logvar = nn.Linear(in_features=self.size[4]*hw*hw, out_features=self.latents) |
|
|
|
def forward(self, x): |
|
x = self.dconv_down1(x) |
|
x = self.maxpool(x) |
|
|
|
x = self.dconv_down2(x) |
|
x = self.maxpool(x) |
|
|
|
conv3 = self.dconv_down3(x) |
|
x = self.maxpool(conv3) |
|
|
|
conv4 = self.dconv_down4(x) |
|
x = self.maxpool(conv4) |
|
|
|
conv5 = self.dconv_down5(x) |
|
x = self.maxpool(conv5) |
|
|
|
conv6 = self.dconv_down6(x) |
|
|
|
x = conv6.view(conv6.size(0), -1) |
|
|
|
x_mu = self.fc_mu(x) |
|
x_logvar = self.fc_logvar(x) |
|
|
|
return x_mu, x_logvar, conv6, conv5 |
|
|
|
|
|
class SkipBlock(nn.Module): |
|
def __init__(self, in_filters, window): |
|
super(SkipBlock, self).__init__() |
|
|
|
self.window = window |
|
self.graphConv_pre = ChebConv(in_filters, 2, 1, bias = False) |
|
|
|
def lookup(self, pos, layer, salida = (1,1)): |
|
B = pos.shape[0] |
|
N = pos.shape[1] |
|
F = layer.shape[1] |
|
h = layer.shape[-1] |
|
|
|
|
|
pos = pos * h |
|
|
|
_x1 = (self.window[0] // 2) * 1.0 |
|
_x2 = (self.window[0] // 2 + 1) * 1.0 |
|
_y1 = (self.window[1] // 2) * 1.0 |
|
_y2 = (self.window[1] // 2 + 1) * 1.0 |
|
|
|
boxes = [] |
|
for batch in range(0, B): |
|
x1 = pos[batch,:,0].reshape(-1, 1) - _x1 |
|
x2 = pos[batch,:,0].reshape(-1, 1) + _x2 |
|
y1 = pos[batch,:,1].reshape(-1, 1) - _y1 |
|
y2 = pos[batch,:,1].reshape(-1, 1) + _y2 |
|
|
|
aux = torch.cat([x1, y1, x2, y2], axis = 1) |
|
boxes.append(aux) |
|
|
|
skip = roi_align(layer, boxes, output_size = salida, aligned=True) |
|
vista = skip.view([B, N, -1]) |
|
|
|
return vista |
|
|
|
def forward(self, x, adj, conv_layer): |
|
pos = self.graphConv_pre(x, adj) |
|
skip = self.lookup(pos, conv_layer) |
|
|
|
return torch.cat((x, skip, pos), axis = 2), pos |
|
|
|
|
|
class Hybrid(nn.Module): |
|
def __init__(self, config, downsample_matrices, upsample_matrices, adjacency_matrices): |
|
super(Hybrid, self).__init__() |
|
|
|
self.config = config |
|
hw = config['inputsize'] // 32 |
|
self.z = config['latents'] |
|
self.encoder = EncoderConv(latents = self.z, hw = hw) |
|
|
|
self.downsample_matrices = downsample_matrices |
|
self.upsample_matrices = upsample_matrices |
|
self.adjacency_matrices = adjacency_matrices |
|
self.kld_weight = 1e-5 |
|
|
|
n_nodes = config['n_nodes'] |
|
self.filters = config['filters'] |
|
self.K = 6 |
|
self.window = (3,3) |
|
|
|
|
|
outshape = self.filters[-1] * n_nodes[-1] |
|
self.dec_lin = torch.nn.Linear(self.z, outshape) |
|
|
|
self.normalization2u = torch.nn.InstanceNorm1d(self.filters[1]) |
|
self.normalization3u = torch.nn.InstanceNorm1d(self.filters[2]) |
|
self.normalization4u = torch.nn.InstanceNorm1d(self.filters[3]) |
|
self.normalization5u = torch.nn.InstanceNorm1d(self.filters[4]) |
|
self.normalization6u = torch.nn.InstanceNorm1d(self.filters[5]) |
|
|
|
outsize1 = self.encoder.size[4] |
|
outsize2 = self.encoder.size[4] |
|
|
|
|
|
self.graphConv_up6 = ChebConv(self.filters[6], self.filters[5], self.K) |
|
self.graphConv_up5 = ChebConv(self.filters[5], self.filters[4], self.K) |
|
|
|
self.SC_1 = SkipBlock(self.filters[4], self.window) |
|
|
|
self.graphConv_up4 = ChebConv(self.filters[4] + outsize1 + 2, self.filters[3], self.K) |
|
self.graphConv_up3 = ChebConv(self.filters[3], self.filters[2], self.K) |
|
|
|
self.SC_2 = SkipBlock(self.filters[2], self.window) |
|
|
|
self.graphConv_up2 = ChebConv(self.filters[2] + outsize2 + 2, self.filters[1], self.K) |
|
self.graphConv_up1 = ChebConv(self.filters[1], self.filters[0], 1, bias = False) |
|
|
|
self.pool = Pool() |
|
|
|
self.reset_parameters() |
|
|
|
def reset_parameters(self): |
|
torch.nn.init.normal_(self.dec_lin.weight, 0, 0.1) |
|
|
|
|
|
def sampling(self, mu, log_var): |
|
std = torch.exp(0.5*log_var) |
|
eps = torch.randn_like(std) |
|
return eps.mul(std).add_(mu) |
|
|
|
|
|
def forward(self, x): |
|
self.mu, self.log_var, conv6, conv5 = self.encoder(x) |
|
|
|
if self.training: |
|
z = self.sampling(self.mu, self.log_var) |
|
else: |
|
z = self.mu |
|
|
|
x = self.dec_lin(z) |
|
x = F.relu(x) |
|
|
|
x = x.reshape(x.shape[0], -1, self.filters[-1]) |
|
|
|
x = self.graphConv_up6(x, self.adjacency_matrices[5]._indices()) |
|
x = self.normalization6u(x) |
|
x = F.relu(x) |
|
|
|
x = self.graphConv_up5(x, self.adjacency_matrices[4]._indices()) |
|
x = self.normalization5u(x) |
|
x = F.relu(x) |
|
|
|
x, pos1 = self.SC_1(x, self.adjacency_matrices[3]._indices(), conv6) |
|
|
|
x = self.graphConv_up4(x, self.adjacency_matrices[3]._indices()) |
|
x = self.normalization4u(x) |
|
x = F.relu(x) |
|
|
|
x = self.pool(x, self.upsample_matrices[0]) |
|
|
|
x = self.graphConv_up3(x, self.adjacency_matrices[2]._indices()) |
|
x = self.normalization3u(x) |
|
x = F.relu(x) |
|
|
|
x, pos2 = self.SC_2(x, self.adjacency_matrices[1]._indices(), conv5) |
|
|
|
x = self.graphConv_up2(x, self.adjacency_matrices[1]._indices()) |
|
x = self.normalization2u(x) |
|
x = F.relu(x) |
|
|
|
x = self.graphConv_up1(x, self.adjacency_matrices[0]._indices()) |
|
|
|
return x, pos1, pos2 |