import math
from math import pi as PI

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.parallel
import torch.utils.data
from torch.nn import ModuleList, Parameter

from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn.dense.linear import Linear
from torch_geometric.nn.inits import glorot, zeros
from torch_geometric.utils import softmax
from torch_scatter import scatter

from utils.utils import triplets, get_angle, GaussianSmearing

"""
Theory-based grid cell spatial relation encoder.

See https://openreview.net/forum?id=Syx0Mh05YQ
"Learning Grid Cells as Vector Representation of Self-Position Coupled with
Matrix Representation of Self-Motion"
"""


def _cal_freq_list(freq_init, frequency_num, max_radius, min_radius):
    """Build the list of frequencies used by the sinusoidal position encoding."""
    if freq_init == "random":
        # Sample each frequency uniformly at random from [0, max_radius).
        freq_list = np.random.random(size=[frequency_num]) * max_radius
    elif freq_init == "geometric":
        # Geometric progression of timescales between min_radius and max_radius
        # (the same scheme as the Transformer positional encoding); the
        # frequencies are the reciprocals of these timescales.
        log_timescale_increment = (math.log(float(max_radius) / float(min_radius)) /
                                   (frequency_num * 1.0 - 1))
        timescales = min_radius * np.exp(
            np.arange(frequency_num).astype(float) * log_timescale_increment)
        freq_list = 1.0 / timescales
    return freq_list
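

# The demo below is not part of the original module; it is a minimal sketch
# showing how the "geometric" initialisation spaces the frequencies. The helper
# name `_demo_freq_list` is hypothetical and only runs when called explicitly.
def _demo_freq_list():
    # With min_radius=10 and max_radius=1000, the timescales grow geometrically
    # from 10 to 1000, so the frequencies shrink from 1/10 to 1/1000.
    freqs = _cal_freq_list("geometric", frequency_num=4, max_radius=1000, min_radius=10)
    # Expected (approximately): [0.1, 0.0215, 0.00464, 0.001]
    print(freqs)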


class TheoryGridCellSpatialRelationEncoder(nn.Module):
    """
    Given a list of (deltaX, deltaY) offsets, encode them with the grid-cell
    position encoding function.
    """

    def __init__(self, spa_embed_dim, coord_dim=2, frequency_num=16,
                 max_radius=10000, min_radius=1000, freq_init="geometric", ffn=None):
        """
        Args:
            spa_embed_dim: the output spatial relation embedding dimension
            coord_dim: the dimension of the space (2D, 3D, or other)
            frequency_num: the number of sinusoids with different frequencies/wavelengths
            max_radius: the largest context radius this model can handle
            min_radius: the smallest context radius (sets the highest frequency)
            freq_init: "geometric" or "random" frequency initialisation
            ffn: optional feed-forward network applied to the raw encoding
        """
        super(TheoryGridCellSpatialRelationEncoder, self).__init__()
        self.frequency_num = frequency_num
        self.coord_dim = coord_dim
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.spa_embed_dim = spa_embed_dim
        self.freq_init = freq_init

        # Precompute the frequency list and the matching frequency matrix.
        self.cal_freq_list()
        self.cal_freq_mat()

        # Three unit vectors, 120 degrees apart, used by the grid-cell encoding.
        self.unit_vec1 = np.asarray([1.0, 0.0])
        self.unit_vec2 = np.asarray([-1.0 / 2.0, math.sqrt(3) / 2.0])
        self.unit_vec3 = np.asarray([-1.0 / 2.0, -math.sqrt(3) / 2.0])

        self.input_embed_dim = self.cal_input_dim()
        self.ffn = ffn

    def cal_freq_list(self):
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        # Shape (frequency_num, 1) -> (frequency_num, 6): one column per
        # (sin, cos) slot for each of the three unit vectors.
        freq_mat = np.expand_dims(self.freq_list, axis=1)
        self.freq_mat = np.repeat(freq_mat, 6, axis=1)

    def cal_input_dim(self):
        # Dimension of the raw sinusoidal encoding before any FFN.
        return int(6 * self.frequency_num)

    def make_input_embeds(self, coords):
        # Accept a numpy array, a nested python list, or a torch tensor of
        # shape (batch_size, num_context_pt, coord_dim).
        if isinstance(coords, np.ndarray):
            assert self.coord_dim == np.shape(coords)[2]
            coords = list(coords)
        elif isinstance(coords, list):
            assert self.coord_dim == len(coords[0][0])
        elif isinstance(coords, torch.Tensor):
            assert self.coord_dim == coords.shape[2]
            coords = coords.detach().cpu().numpy()
        else:
            raise Exception("Unknown coords data type for GridCellSpatialRelationEncoder")

        coords_mat = np.asarray(coords).astype(float)
        batch_size = coords_mat.shape[0]
        num_context_pt = coords_mat.shape[1]

        # Project each (deltaX, deltaY) onto the three unit vectors.
        angle_mat1 = np.expand_dims(np.matmul(coords_mat, self.unit_vec1), axis=-1)
        angle_mat2 = np.expand_dims(np.matmul(coords_mat, self.unit_vec2), axis=-1)
        angle_mat3 = np.expand_dims(np.matmul(coords_mat, self.unit_vec3), axis=-1)

        # Duplicate each projection so adjacent channels hold the sin/cos pair,
        # then scale by every frequency.
        angle_mat = np.concatenate([angle_mat1, angle_mat1, angle_mat2, angle_mat2, angle_mat3, angle_mat3], axis=-1)
        angle_mat = np.expand_dims(angle_mat, axis=-2)
        angle_mat = np.repeat(angle_mat, self.frequency_num, axis=-2)
        angle_mat = angle_mat * self.freq_mat
        spr_embeds = np.reshape(angle_mat, (batch_size, num_context_pt, -1))

        # Interleave sin on the even channels and cos on the odd channels.
        spr_embeds[:, :, 0::2] = np.sin(spr_embeds[:, :, 0::2])
        spr_embeds[:, :, 1::2] = np.cos(spr_embeds[:, :, 1::2])

        return spr_embeds

    def forward(self, coords):
        """
        Given a list of coords (deltaX, deltaY), return their spatial relation embedding.
        Args:
            coords: a python list with shape (batch_size, num_context_pt, coord_dim)
        Return:
            sprenc: Tensor of shape (batch_size, num_context_pt, spa_embed_dim) when an
                ffn is provided, otherwise (batch_size, num_context_pt, input_embed_dim)
        """
        spr_embeds = self.make_input_embeds(coords)
        spr_embeds = torch.FloatTensor(spr_embeds)
        if self.ffn is not None:
            return self.ffn(spr_embeds)
        else:
            # Without an ffn, the raw encoding of size input_embed_dim is returned.
            return spr_embeds


theoryencoder = TheoryGridCellSpatialRelationEncoder(spa_embed_dim=8)
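

# A minimal usage sketch (not part of the original module); the helper name
# `_demo_theory_encoder` is hypothetical. With ffn=None the encoder returns the
# raw sinusoidal encoding of size 6 * frequency_num (96 here), not spa_embed_dim.
def _demo_theory_encoder():
    # Four (deltaX, deltaY) offsets for a single batch element.
    deltas = np.array([[[0.0, 0.0], [100.0, 0.0], [0.0, 500.0], [-250.0, 250.0]]])
    emb = theoryencoder(deltas)
    print(emb.shape)  # torch.Size([1, 4, 96])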


class GFusion(nn.Module):
    def __init__(self, h_channel=16, input_featuresize=32, localdepth=2, num_interactions=3,
                 finaldepth=3, num_of_datasources=2, share="111", batchnorm="False"):
        # `share` is a three-character string of 0/1 flags: whether to share the
        # geometry MLPs, the interaction blocks, and the final MLPs across data sources.
        super(GFusion, self).__init__()
        self.training = True
        self.h_channel = h_channel
        self.input_featuresize = input_featuresize
        self.localdepth = localdepth
        self.num_interactions = num_interactions
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.activation = nn.ReLU()

        # One channel for the inverse distance, twelve for the smeared angle.
        num_gaussians = (1, 12)
        self.theta_expansion = GaussianSmearing(-PI, PI, num_gaussians[1])
        self.mlps_list = ModuleList()
        if int(share[0]) == 1:
            # One geometry MLP shared by all data sources.
            mlp_geo = ModuleList()
            for i in range(self.localdepth):
                if i == 0:
                    mlp_geo.append(Linear(sum(num_gaussians), h_channel))
                else:
                    mlp_geo.append(Linear(h_channel, h_channel))
                if self.batchnorm == "True":
                    mlp_geo.append(nn.BatchNorm1d(h_channel))
                mlp_geo.append(self.activation)
            for i in range(num_of_datasources):
                self.mlps_list.append(mlp_geo)
        else:
            # Separate geometry MLP per data source.
            for i in range(num_of_datasources):
                mlp_geo = ModuleList()
                for j in range(self.localdepth):
                    if j == 0:
                        mlp_geo.append(Linear(sum(num_gaussians), h_channel))
                    else:
                        mlp_geo.append(Linear(h_channel, h_channel))
                    if self.batchnorm == "True":
                        mlp_geo.append(nn.BatchNorm1d(h_channel))
                    mlp_geo.append(self.activation)
                self.mlps_list.append(mlp_geo)
        # Fallback MLPs used when edge_rep is False: they consume the raw
        # (coords_j, coords_i) concatenation, which has 4 features.
        self.mlps_list_backup = ModuleList()
        for i in range(num_of_datasources):
            mlp_geo = ModuleList()
            for j in range(self.localdepth):
                if j == 0:
                    mlp_geo.append(Linear(4, h_channel))
                else:
                    mlp_geo.append(Linear(h_channel, h_channel))
                if self.batchnorm == "True":
                    mlp_geo.append(nn.BatchNorm1d(h_channel))
                mlp_geo.append(self.activation)
            self.mlps_list_backup.append(mlp_geo)
        # Projects the raw node features (all but the last two columns of
        # input_feature, see single_forward) into the hidden space.
        self.translinear = Linear(input_featuresize + 1, self.h_channel)
        self.interactions_list = ModuleList()
        if int(share[1]) == 1:
            # One stack of SPNN interaction blocks shared by all data sources.
            interactions = ModuleList()
            for i in range(self.num_interactions):
                block = SPNN(
                    in_ch=self.input_featuresize,
                    hidden_channels=self.h_channel,
                    activation=self.activation,
                    finaldepth=self.finaldepth,
                    batchnorm=self.batchnorm,
                    num_input_geofeature=self.h_channel,
                )
                interactions.append(block)
            for i in range(num_of_datasources):
                self.interactions_list.append(interactions)
        else:
            for i in range(num_of_datasources):
                interactions = ModuleList()
                for j in range(self.num_interactions):
                    block = SPNN(
                        in_ch=self.input_featuresize,
                        hidden_channels=self.h_channel,
                        activation=self.activation,
                        finaldepth=self.finaldepth,
                        batchnorm=self.batchnorm,
                        num_input_geofeature=self.h_channel,
                    )
                    interactions.append(block)
                self.interactions_list.append(interactions)
        # Per-data-source read-out MLPs mapping node embeddings to a scalar.
        self.finalMLP_list = ModuleList()
        if int(share[2]) == 1:
            finalMLP = ModuleList()
            for i in range(self.finaldepth + 1):
                finalMLP.append(Linear(self.h_channel, self.h_channel))
                if self.batchnorm == "True":
                    finalMLP.append(nn.BatchNorm1d(self.h_channel))
                finalMLP.append(self.activation)
            finalMLP.append(Linear(self.h_channel, 1))
            for i in range(num_of_datasources):
                self.finalMLP_list.append(finalMLP)
        else:
            for i in range(num_of_datasources):
                finalMLP = ModuleList()
                for j in range(self.finaldepth + 1):
                    finalMLP.append(Linear(self.h_channel, self.h_channel))
                    if self.batchnorm == "True":
                        finalMLP.append(nn.BatchNorm1d(self.h_channel))
                    finalMLP.append(self.activation)
                finalMLP.append(Linear(self.h_channel, 1))
                self.finalMLP_list.append(finalMLP)
        self.reset_parameters()

    def reset_parameters(self):
        for i in range(len(self.mlps_list)):
            for lin in self.mlps_list[i]:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)
        for i in range(len(self.interactions_list)):
            for block in self.interactions_list[i]:
                block.reset_parameters()
        for finalMLP in self.finalMLP_list:
            for lin in finalMLP:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)

    def single_forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch,
                       input_feature, is_source, edge_rep, datasource_idx):
        distances = {}
        thetas = {}
        if edge_rep:
            # Geometric edge representation: inverse edge length plus a signed
            # angle between consecutive edges, smeared into Gaussian bins.
            i, j, k = edge_index_2rd
            distances[1] = (coords[edge_index[0]] - coords[edge_index[1]]).norm(p=2, dim=1)
            theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
            # Use the z-component of the padded cross product to get the angle's sign.
            v1 = torch.cross(F.pad(coords[j] - coords[i], (0, 1)),
                             F.pad(coords[k] - coords[j], (0, 1)), dim=1)[..., 2]
            flag = torch.sign(v1)
            flag[flag == 0] = -1
            thetas[1] = scatter(theta_ijk * flag, edx_2nd, dim=0,
                                dim_size=edge_index.shape[1], reduce='min')
            thetas[1] = self.theta_expansion(thetas[1])
            geo_encoding_1st = distances[1][:, None]
            geo_encoding_1st[geo_encoding_1st == 0] = 1E-10
            geo_encoding_1st = torch.pow(geo_encoding_1st, -1)
            geo_encoding_2nd = thetas[1]
            geo_encoding = torch.cat([geo_encoding_1st, geo_encoding_2nd], dim=-1)
        else:
            # Fallback: use the raw endpoint coordinates as the edge descriptor.
            coords_j = coords[edge_index[0]]
            coords_i = coords[edge_index[1]]
            geo_encoding = torch.cat([coords_j, coords_i], dim=-1)
        if edge_rep:
            for lin in self.mlps_list[datasource_idx]:
                geo_encoding = lin(geo_encoding)
        else:
            for lin in self.mlps_list_backup[datasource_idx]:
                geo_encoding = lin(geo_encoding)
            # The fallback encoding is zeroed out, i.e. no geometric information is used.
            geo_encoding = torch.zeros_like(geo_encoding, device=geo_encoding.device, dtype=geo_encoding.dtype)
        node_feature = self.translinear(input_feature[:, :-2])
        for interaction in self.interactions_list[datasource_idx]:
            node_feature = interaction(node_feature, geo_encoding, edge_index, is_source)
        return node_feature

    def forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch,
                input_feature, is_source, edge_rep):
        # Run one branch per data source and apply that source's read-out MLP.
        outputs = []
        for i in range(len(coords)):
            output = self.single_forward(coords[i], edge_index[i], edge_index_2rd[i], edx_2nd[i],
                                         batch[i], input_feature[i], is_source[i], edge_rep, i)
            for lin in self.finalMLP_list[i]:
                output = lin(output)
            outputs.append(output)
        return outputs


class SPNN(torch.nn.Module):
    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=torch.nn.ReLU(),
        finaldepth=3,
        batchnorm="False",
        num_input_geofeature=13,
    ):
        super(SPNN, self).__init__()
        self.activation = activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature = num_input_geofeature
        # Attention vector used to score each edge message.
        self.att = Parameter(torch.Tensor(1, hidden_channels), requires_grad=True)

        self.WMLP = ModuleList()
        for i in range(self.finaldepth + 1):
            if i == 0:
                self.WMLP.append(Linear(hidden_channels * 2 + num_input_geofeature, hidden_channels))
            else:
                self.WMLP.append(Linear(hidden_channels, hidden_channels))
            if self.batchnorm == "True":
                self.WMLP.append(nn.BatchNorm1d(hidden_channels))
            self.WMLP.append(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        for lin in self.WMLP:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        glorot(self.att)

    def forward(self, node_feature, geo_encoding, edge_index, is_source):
        j, i = edge_index
        if node_feature is None:
            concatenated_vector = geo_encoding
        else:
            # Target and source node features for each edge, plus the edge geometry.
            node_attr_0st = node_feature[i]
            node_attr_1st = node_feature[j]
            concatenated_vector = torch.cat(
                [
                    node_attr_0st,
                    node_attr_1st,
                    geo_encoding,
                ],
                dim=-1,
            )
        input_feature = node_feature.clone()
        x_i = concatenated_vector
        for lin in self.WMLP:
            x_i = lin(x_i)
        input_feature_j = input_feature[edge_index[0]]
        x_i = F.leaky_relu(x_i)
        # Attention weights over the incoming edges of each target node.
        alpha = F.leaky_relu(x_i * self.att).sum(dim=-1)
        alpha = softmax(alpha, edge_index[1])

        # Aggregate the attention-weighted source features onto the target nodes
        # and add a residual connection back onto the input node features.
        message = input_feature_j * alpha.unsqueeze(-1)
        out_feature = scatter(message, edge_index[1], dim=0, reduce='add')
        out_feature = input_feature + out_feature

        return out_feature
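

# A minimal shape-check sketch for SPNN (not part of the original module). The
# helper name `_demo_spnn` is hypothetical; it builds a tiny 3-node graph and
# only runs when called explicitly.
def _demo_spnn():
    hidden = 8
    block = SPNN(in_ch=hidden, hidden_channels=hidden, finaldepth=1,
                 num_input_geofeature=hidden)
    node_feature = torch.randn(3, hidden)          # 3 nodes
    edge_index = torch.tensor([[0, 1, 2, 0],       # source nodes j
                               [1, 2, 0, 2]])      # target nodes i
    geo_encoding = torch.randn(edge_index.shape[1], hidden)
    out = block(node_feature, geo_encoding, edge_index, is_source=None)
    print(out.shape)  # torch.Size([3, 8])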