import torch |
import torch.nn as nn |
import torch.nn.parallel |
import torch.utils.data |
import numpy as np |
import torch.nn.functional as F |
from torch.nn import Parameter |
from torch_geometric.nn.dense.linear import Linear |
from torch_geometric.nn.conv import MessagePassing |
from torch_geometric.utils import softmax |
from torch_geometric.nn.inits import glorot, zeros |
from torch_scatter import scatter |
from utils.utils import triplets,get_angle,GaussianSmearing |
from torch.nn import ModuleList |
from math import pi as PI |
import math |
"""
Theory-based grid-cell spatial relation encoder.
See https://openreview.net/forum?id=Syx0Mh05YQ
("Learning Grid Cells as Vector Representation of Self-Position Coupled with
Matrix Representation of Self-Motion").
"""
def _cal_freq_list(freq_init, frequency_num, max_radius, min_radius):
    """
    Return the ``frequency_num`` sinusoid frequencies used by the grid-cell encoder.

    Args:
        freq_init: ``"random"`` draws frequencies uniformly from ``[0, max_radius)``;
            ``"geometric"`` returns ``1 / timescale`` for timescales spaced
            geometrically from ``min_radius`` to ``max_radius`` (both inclusive).
        frequency_num: number of frequencies to produce.
        max_radius: largest timescale ("geometric") or upper bound ("random").
        min_radius: smallest timescale ("geometric" only).

    Returns:
        1-D float ndarray of length ``frequency_num``.

    Raises:
        ValueError: if ``freq_init`` is not a supported scheme (previously this
            fell through to an UnboundLocalError).
    """
    if freq_init == "random":
        return np.random.random(size=[frequency_num]) * max_radius
    if freq_init == "geometric":
        # With a single frequency there is no spacing to compute (the formula
        # below would divide by zero); the lone timescale is min_radius.
        if frequency_num > 1:
            log_timescale_increment = (math.log(float(max_radius) / float(min_radius)) /
                                       (frequency_num * 1.0 - 1))
        else:
            log_timescale_increment = 0.0
        timescales = min_radius * np.exp(
            np.arange(frequency_num).astype(float) * log_timescale_increment)
        return 1.0 / timescales
    raise ValueError(
        f"Unknown freq_init {freq_init!r}; expected 'random' or 'geometric'")
class TheoryGridCellSpatialRelationEncoder(nn.Module):
    """
    Encode (deltaX, deltaY) displacement vectors with the theory-based
    grid-cell positional encoding.

    Each coordinate is projected onto three unit vectors spaced 120 degrees
    apart. Every projection is scaled by each frequency in ``freq_list``, and
    sin/cos of the results gives ``6 * frequency_num`` features per point. An
    optional ``ffn`` is applied to the resulting float32 tensor.
    """

    def __init__(self, spa_embed_dim, coord_dim = 2, frequency_num = 16,
                 max_radius = 10000, min_radius = 1000, freq_init = "geometric", ffn = None):
        """
        Args:
            spa_embed_dim: the output spatial relation embedding dimension. It is
                only stored here; the raw encoding has ``6 * frequency_num``
                features, and any projection is left to ``ffn``.
            coord_dim: the dimension of the coordinate space. The projection
                unit vectors are 2-D, so only ``2`` works with ``make_input_embeds``.
            frequency_num: the number of sinusoid frequencies/wavelengths.
            max_radius: the largest context radius this model can handle.
            min_radius: the smallest context radius (used by "geometric" init).
            freq_init: frequency scheme passed to ``_cal_freq_list``
                ("geometric" or "random").
            ffn: optional callable applied to the embedding tensor in ``forward``.
        """
        super(TheoryGridCellSpatialRelationEncoder, self).__init__()
        self.frequency_num = frequency_num
        self.coord_dim = coord_dim
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.spa_embed_dim = spa_embed_dim
        self.freq_init = freq_init
        self.cal_freq_list()
        self.cal_freq_mat()
        # Three unit vectors at 0, 120 and 240 degrees.
        self.unit_vec1 = np.asarray([1.0, 0.0])
        self.unit_vec2 = np.asarray([-1.0/2.0, math.sqrt(3)/2.0])
        self.unit_vec3 = np.asarray([-1.0/2.0, -math.sqrt(3)/2.0])
        self.input_embed_dim = self.cal_input_dim()
        self.ffn = ffn

    def cal_freq_list(self):
        """Compute ``self.freq_list`` (length ``frequency_num``) from the init settings."""
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        """Tile ``freq_list`` into a (frequency_num, 6) matrix, one column per projection slot."""
        freq_mat = np.expand_dims(self.freq_list, axis = 1)
        self.freq_mat = np.repeat(freq_mat, 6, axis = 1)

    def cal_input_dim(self):
        """Return the raw embedding width: 3 projections x (sin, cos) x frequencies."""
        return int(6 * self.frequency_num)

    def make_input_embeds(self, coords):
        """
        Compute the raw sinusoidal grid-cell features as a NumPy array.

        Args:
            coords: coordinates of shape (batch_size, num_context_pt, coord_dim),
                given as a NumPy array, a nested list/tuple, or a torch tensor.

        Returns:
            float64 ndarray of shape (batch_size, num_context_pt, 6 * frequency_num).
            Even feature columns hold sin values and odd columns hold cos values.

        Raises:
            ValueError: if the third axis does not match ``coord_dim``.
            TypeError: if ``coords`` is of an unsupported type.
        """
        if isinstance(coords, torch.Tensor):
            coords = coords.detach().cpu().numpy()
        elif not isinstance(coords, (np.ndarray, list, tuple)):
            raise TypeError("Unknown coords data type for GridCellSpatialRelationEncoder")
        coords_mat = np.asarray(coords).astype(float)
        if coords_mat.ndim < 3 or coords_mat.shape[2] != self.coord_dim:
            raise ValueError(
                f"expected coords of shape (batch_size, num_context_pt, {self.coord_dim}), "
                f"got {coords_mat.shape}")
        batch_size, num_context_pt = coords_mat.shape[:2]

        # Project onto each unit vector. Each projection is duplicated so that,
        # after flattening, one copy lands in a sin column and one in a cos column.
        projections = [np.matmul(coords_mat, unit_vec)[..., None]
                       for unit_vec in (self.unit_vec1, self.unit_vec2, self.unit_vec3)]
        angle_mat = np.concatenate([proj for proj in projections for _ in range(2)], axis=-1)

        # (batch, n, 6) -> (batch, n, frequency_num, 6): scale by every frequency.
        angle_mat = angle_mat[..., None, :] * self.freq_mat
        spr_embeds = angle_mat.reshape(batch_size, num_context_pt, -1)
        spr_embeds[:, :, 0::2] = np.sin(spr_embeds[:, :, 0::2])
        spr_embeds[:, :, 1::2] = np.cos(spr_embeds[:, :, 1::2])
        return spr_embeds

    def forward(self, coords):
        """
        Given coords (deltaX, deltaY), return their spatial relation embedding.

        Args:
            coords: shape (batch_size, num_context_pt, coord_dim); see ``make_input_embeds``.

        Returns:
            float32 tensor of shape (batch_size, num_context_pt, 6 * frequency_num),
            passed through ``ffn`` when one is set. For tensor inputs, the result
            is placed on the input's device.
        """
        spr_embeds = torch.as_tensor(self.make_input_embeds(coords), dtype=torch.float32)
        if isinstance(coords, torch.Tensor):
            # The NumPy round trip always lands on the CPU; move back so a
            # device-resident ffn (or downstream op) receives a matching tensor.
            spr_embeds = spr_embeds.to(coords.device)
        if self.ffn is not None:
            return self.ffn(spr_embeds)
        return spr_embeds
# Module-level encoder instance built with the class's default frequency settings.
theoryencoder = TheoryGridCellSpatialRelationEncoder(
    spa_embed_dim=8,
)
class GFusion(nn.Module):
    """
    Multi-source spatial graph network.

    For each data source ``i``, the model:
    - encodes edge geometry with ``mlps_list[i]``, or with
      ``mlps_list_backup[i]`` when ``edge_rep`` is False;
    - embeds node features with the single shared ``translinear``;
    - runs the ``num_interactions`` SPNN blocks in ``interactions_list[i]``;
    - maps every node to one scalar with ``finalMLP_list[i]``.
    """

    def __init__(self, h_channel=16, input_featuresize=32, localdepth=2, num_interactions=3,
                 finaldepth=3, num_of_datasources=2, share=True, batchnorm="False"):
        """
        Args:
            h_channel: hidden width used throughout the network.
            input_featuresize: ``translinear`` takes ``input_featuresize + 1`` inputs.
            localdepth: number of Linear layers in each geometry MLP.
            num_interactions: number of SPNN blocks per data source.
            finaldepth: each output MLP has ``finaldepth + 1`` hidden Linear
                layers plus one Linear(h_channel, 1).
            num_of_datasources: number of data sources.
            share: three flags (e.g. ``"101"`` or ``[1, 0, 1]``) choosing whether
                the geometry MLP, the interaction stack and the output MLP are
                shared (1) or built separately per source. A plain bool/int
                applies the same choice to all three. The old default ``True``
                could not be indexed and crashed.
            batchnorm: the string ``"True"`` inserts BatchNorm1d after each hidden Linear.
        """
        super(GFusion, self).__init__()
        self.h_channel = h_channel
        self.input_featuresize = input_featuresize
        self.localdepth = localdepth
        self.num_interactions = num_interactions
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.activation = nn.ReLU()
        # The first geometry Linear takes sum(num_gaussians) = 13 inputs: one
        # inverse-distance column plus the theta_expansion output, which is
        # expected to be 12 wide.
        num_gaussians = (1, 12)
        self.theta_expansion = GaussianSmearing(-PI, PI, num_gaussians[1])

        if isinstance(share, (bool, int)):
            share = (int(share),) * 3
        share_geo, share_interactions, share_final = (int(flag) == 1 for flag in share[:3])

        # Construction order matches the original so that RNG consumption (and
        # therefore initial weights) is unchanged.
        self.mlps_list = self._replicate(
            lambda: self._build_geo_mlp(sum(num_gaussians)), num_of_datasources, share_geo)
        # Fallback geometry MLPs over the concatenated endpoint coordinates
        # (Linear(4, ...) implies 2-D coordinates); always one per source.
        self.mlps_list_backup = self._replicate(
            lambda: self._build_geo_mlp(4), num_of_datasources, False)
        self.translinear = Linear(input_featuresize + 1, self.h_channel)
        self.interactions_list = self._replicate(
            self._build_interactions, num_of_datasources, share_interactions)
        self.finalMLP_list = self._replicate(
            self._build_final_mlp, num_of_datasources, share_final)
        self.reset_parameters()

    @staticmethod
    def _replicate(build, count, shared):
        """Return a ModuleList of ``count`` entries: one ``build()`` reused if ``shared``, else fresh builds."""
        if shared:
            module = build()
            return ModuleList([module] * count)
        return ModuleList([build() for _ in range(count)])

    def _build_geo_mlp(self, in_dim):
        """Build one geometry MLP: ``localdepth`` x (Linear [+ BatchNorm1d] + activation)."""
        layers = ModuleList()
        for depth in range(self.localdepth):
            layers.append(Linear(in_dim if depth == 0 else self.h_channel, self.h_channel))
            if self.batchnorm == "True":
                layers.append(nn.BatchNorm1d(self.h_channel))
            layers.append(self.activation)
        return layers

    def _build_interactions(self):
        """Build one stack of ``num_interactions`` SPNN blocks."""
        return ModuleList([
            SPNN(
                in_ch=self.input_featuresize,
                hidden_channels=self.h_channel,
                activation=self.activation,
                finaldepth=self.finaldepth,
                batchnorm=self.batchnorm,
                num_input_geofeature=self.h_channel,
            )
            for _ in range(self.num_interactions)
        ])

    def _build_final_mlp(self):
        """Build one output MLP ending in a Linear(h_channel, 1)."""
        layers = ModuleList()
        for _ in range(self.finaldepth + 1):
            layers.append(Linear(self.h_channel, self.h_channel))
            if self.batchnorm == "True":
                layers.append(nn.BatchNorm1d(self.h_channel))
            layers.append(self.activation)
        layers.append(Linear(self.h_channel, 1))
        return layers

    @staticmethod
    def _reset_linears(layers):
        """Apply Xavier-uniform weights and zero bias to every ``Linear`` in ``layers``."""
        for layer in layers:
            if isinstance(layer, Linear):
                torch.nn.init.xavier_uniform_(layer.weight)
                layer.bias.data.fill_(0)

    def reset_parameters(self):
        """Re-initialise the geometry MLPs, interaction blocks and output MLPs.

        Shared modules appear once per source in their list and are therefore
        re-initialised once per appearance, as before.
        """
        for mlp_geo in self.mlps_list:
            self._reset_linears(mlp_geo)
        for interactions in self.interactions_list:
            for block in interactions:
                block.reset_parameters()
        for final_mlp in self.finalMLP_list:
            self._reset_linears(final_mlp)

    def single_forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch, input_feature,
                       is_source, edge_rep, datasource_idx):
        """
        Run the geometry encoder and the interaction stack for one data source.

        Args:
            coords: node coordinates indexed by node id. They are padded with one
                zero column for the cross product, so presumably 2-D (TODO confirm).
            edge_index: 2 x E edge list (row 0 = source ``j``, row 1 = target ``i``).
            edge_index_2rd: three index rows ``(i, j, k)`` of second-order triplets.
            edx_2nd: for each triplet, the edge index its signed angle is reduced onto.
            batch: unused here.
            input_feature: node features; the last two columns are dropped before ``translinear``.
            is_source: passed through to each SPNN block.
            edge_rep: True uses inverse distance + signed angle geometry; False
                uses zeroed geometry.
            datasource_idx: selects the per-source modules.

        Returns:
            Node embeddings of width ``h_channel``.
        """
        if edge_rep:
            i, j, k = edge_index_2rd
            distance = (coords[edge_index[0]] - coords[edge_index[1]]).norm(p=2, dim=1)
            vec_ij = coords[j] - coords[i]
            vec_jk = coords[k] - coords[j]
            theta_ijk = get_angle(vec_ij, vec_jk)
            # The z-component of the zero-padded cross product gives the turn
            # direction; collinear triplets (sign 0) are treated as negative.
            cross_z = torch.cross(F.pad(vec_ij, (0, 1)), F.pad(vec_jk, (0, 1)), dim=1)[..., 2]
            flag = torch.sign(cross_z)
            flag[flag == 0] = -1
            theta = scatter(theta_ijk * flag, edx_2nd, dim=0, dim_size=edge_index.shape[1], reduce='min')
            theta = self.theta_expansion(theta)
            # Inverse distance. Exact zeros are replaced out-of-place so that
            # autograd never sees an in-place write to the norm's output.
            distance = distance[:, None]
            distance = distance.masked_fill(distance == 0, 1E-10)
            geo_encoding = torch.cat([torch.pow(distance, -1), theta], dim=-1)
            for lin in self.mlps_list[datasource_idx]:
                geo_encoding = lin(geo_encoding)
        else:
            geo_encoding = torch.cat([coords[edge_index[0]], coords[edge_index[1]]], dim=-1)
            for lin in self.mlps_list_backup[datasource_idx]:
                geo_encoding = lin(geo_encoding)
            # Geometry is disabled on this path: keep the (E, h_channel) shape but
            # feed zeros. NOTE(review): source indentation was lost; this zeroing
            # is assumed to belong to the edge_rep=False branch only. Confirm.
            geo_encoding = torch.zeros_like(geo_encoding)
        node_feature = self.translinear(input_feature[:, :-2])
        for interaction in self.interactions_list[datasource_idx]:
            node_feature = interaction(node_feature, geo_encoding, edge_index, is_source)
        return node_feature

    def forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch, input_feature, is_source, edge_rep):
        """
        Run every data source and its output MLP.

        Each positional argument except ``edge_rep`` is a per-source sequence
        indexed like ``coords``.

        Returns:
            List with one output tensor per data source.
        """
        outputs = []
        for idx in range(len(coords)):
            output = self.single_forward(coords[idx], edge_index[idx], edge_index_2rd[idx], edx_2nd[idx],
                                         batch[idx], input_feature[idx], is_source[idx], edge_rep, idx)
            for lin in self.finalMLP_list[idx]:
                output = lin(output)
            outputs.append(output)
        return outputs
class SPNN(torch.nn.Module):
    """
    One attention-weighted message-passing layer with a residual connection.

    For every edge ``j -> i`` (``edge_index[0] -> edge_index[1]``):
    - an MLP over ``[x_i, x_j, geo]`` produces a hidden vector;
    - the hidden vector is passed through leaky ReLU, multiplied by ``att``
      and summed to give a score;
    - scores are softmax-normalised over each target's incoming edges.

    Each target node receives the attention-weighted sum of its sources'
    features, added to its own features.
    """

    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=None,
        finaldepth=3,
        batchnorm="False",
        num_input_geofeature=13
    ):
        """
        Args:
            in_ch: accepted for interface compatibility; not used.
            hidden_channels: width of the node features and of the attention MLP.
            activation: module inserted after each MLP layer. Defaults to a new
                ``ReLU``; previously a single ReLU default instance was shared by
                every SPNN.
            finaldepth: the MLP has ``finaldepth + 1`` Linear layers.
            batchnorm: the string ``"True"`` adds BatchNorm1d after each Linear.
            num_input_geofeature: width of the per-edge geometric encoding.
        """
        super(SPNN, self).__init__()
        self.activation = torch.nn.ReLU() if activation is None else activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature = num_input_geofeature
        self.att = Parameter(torch.empty(1, hidden_channels), requires_grad=True)
        self.WMLP = ModuleList()
        for depth in range(self.finaldepth + 1):
            in_dim = hidden_channels * 2 + num_input_geofeature if depth == 0 else hidden_channels
            self.WMLP.append(Linear(in_dim, hidden_channels))
            if self.batchnorm == "True":
                self.WMLP.append(nn.BatchNorm1d(hidden_channels))
            self.WMLP.append(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the MLP weights, zero their biases, and glorot-initialise ``att``."""
        for lin in self.WMLP:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        glorot(self.att)

    def forward(self, node_feature, geo_encoding, edge_index, is_source):
        """
        Args:
            node_feature: node feature matrix (one row per node, ``hidden_channels`` wide). Required.
            geo_encoding: per-edge geometric features, one row per edge.
            edge_index: 2 x E index tensor; row 0 holds sources ``j``, row 1 targets ``i``.
            is_source: unused; kept for interface compatibility.

        Returns:
            Updated node features with the same number of rows as ``node_feature``.

        Raises:
            ValueError: if ``node_feature`` is None. The previous None branch was
                unreachable because ``.clone()`` ran first.
        """
        if node_feature is None:
            raise ValueError("SPNN.forward requires node_feature")
        j, i = edge_index
        num_nodes = node_feature.size(0)
        hidden = torch.cat([node_feature[i], node_feature[j], geo_encoding], dim=-1)
        for layer in self.WMLP:
            hidden = layer(hidden)
        hidden = F.leaky_relu(hidden)
        alpha = F.leaky_relu(hidden * self.att).sum(dim=-1)
        alpha = softmax(alpha, i, num_nodes=num_nodes)
        message = node_feature[j] * alpha.unsqueeze(-1)
        # dim_size keeps the aggregate aligned with node_feature even when the
        # highest-indexed nodes have no incoming edges (otherwise the residual
        # add below fails on a shape mismatch).
        aggregated = scatter(message, i, dim=0, dim_size=num_nodes, reduce='add')
        return node_feature + aggregated