Dzy6 committed
Commit c7995e9
0 Parent(s)
Files changed (9):
  1. README.md +3 -0
  2. data/north/column +7 -0
  3. data/south/column +7 -0
  4. dataset.py +183 -0
  5. model.py +386 -0
  6. requirements.txt +300 -0
  7. run.sh +4 -0
  8. train.py +421 -0
  9. utils/utils.py +67 -0
README.md ADDED
@@ -0,0 +1,3 @@
+ # KDD24 Self-consistent Deep Geometric Learning for Heterogeneous Multi-source Spatial Point Data Prediction
+
+ Data is available on [Dropbox](https://www.dropbox.com/sh/fi5bsxqeuz46h6l/AABSkN6cav7omgvgATX1cs6ga?dl=0).
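For orientation, a minimal sketch of how the downloaded files are read. The paths come from `load_point()`/`process()` in dataset.py below; the per-day pickle layout (a list of `[high_fidelity, low_fidelity]` tensors) is inferred from `process()` and should be treated as an assumption until checked against the archive:

```python
# Sketch: where dataset.py expects the Dropbox files on disk.
import os
import pickle

datasetname = "south"                       # or "north"
root = os.path.join("data", datasetname)
with open(os.path.join(root, datasetname + ".pkl"), "rb") as f:
    days = pickle.load(f)                   # assumed: one [day_d1, day_d2] pair per day
day_d1, day_d2 = days[0]                    # high-fidelity (smaller) / low-fidelity station sets
print(day_d1.shape, day_d2.shape)           # rows = stations; last two columns = Latitude, Longitude
```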
data/north/column ADDED
@@ -0,0 +1,7 @@
+ ['DM8HA' 'Date.Local' 'relative_humidity_max' 'air_temperature_min'
+ 'precipitation' 'air_temperature_max' 'wind_direction' 'solar_radiation'
+ 'relative_humidity_min' 'wind_speed' 'elevation' 'CO' 'NH3' 'NOX' 'SO2'
+ 'VOC' 'PM25-PRI' 'PM10-PRI' 'population_density_county' 'open_water'
+ 'developed' 'bareRock_sand_clay' 'd_forest' 'e_forest' 'm_forest' 'shrub'
+ 'grassland' 'pasture' 'crops' 'w_wetlands' 'eh_wetlands' 'cmaq'
+ 'Latitude' 'Longitude']
data/south/column ADDED
@@ -0,0 +1,7 @@
+ ['DM8HA' 'Date.Local' 'relative_humidity_max' 'air_temperature_min'
+ 'precipitation' 'air_temperature_max' 'wind_direction' 'solar_radiation'
+ 'relative_humidity_min' 'wind_speed' 'elevation' 'CO' 'NH3' 'NOX' 'SO2'
+ 'VOC' 'PM25-PRI' 'PM10-PRI' 'population_density_county' 'open_water'
+ 'developed' 'bareRock_sand_clay' 'd_forest' 'e_forest' 'm_forest' 'shrub'
+ 'grassland' 'pasture' 'crops' 'w_wetlands' 'eh_wetlands' 'cmaq'
+ 'Latitude' 'Longitude']
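Both column files record the feature order of a station row. A hedged sketch of how the 34 names line up with the day tensors: column 0 ('DM8HA', the ozone target) and the trailing Latitude/Longitude pair match how dataset.py masks `x[i, 0]` and slices `x[:, -2:]`, and train.py's `x_in=30` implies 33 tensor columns (31 features into `translinear` plus 2 coordinates), which fits only if 'Date.Local' is dropped before tensor construction — that drop is an assumption:

```python
# Hypothetical helper mapping the printed column list onto a day tensor.
# Assumption: 'Date.Local' is removed when the tensors are built.
columns = ['DM8HA', 'relative_humidity_max', 'air_temperature_min',
           'precipitation', 'air_temperature_max', 'wind_direction',
           'solar_radiation', 'relative_humidity_min', 'wind_speed',
           'elevation', 'CO', 'NH3', 'NOX', 'SO2', 'VOC', 'PM25-PRI',
           'PM10-PRI', 'population_density_county', 'open_water', 'developed',
           'bareRock_sand_clay', 'd_forest', 'e_forest', 'm_forest', 'shrub',
           'grassland', 'pasture', 'crops', 'w_wetlands', 'eh_wetlands',
           'cmaq', 'Latitude', 'Longitude']

def column_index(name: str) -> int:
    """Return the tensor column for a named feature."""
    return columns.index(name)
```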
dataset.py ADDED
@@ -0,0 +1,183 @@
+ from __future__ import print_function
+ import torch.utils.data as data
+ import os
+ import os.path
+ import torch
+ import numpy as np
+ import pandas as pd
+ import sys
+ from torch_geometric.nn import knn_graph
+ from torch_geometric.data import Data
+ from torch_geometric.loader import DataLoader
+ from torch_geometric.utils import add_self_loops
+ from torch_geometric.data.collate import collate
+ from torch_geometric.data.separate import separate
+ import pickle
+ import time
+
+ from torch_geometric.data.data import BaseData
+ from torch_geometric.data.storage import BaseStorage
+ from typing import Any
+
+
+ def mycollate(data_list):
+     r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects
+     into the internal storage format of
+     :class:`~torch_geometric.data.InMemoryDataset`."""
+     if len(data_list) == 1:
+         return data_list[0], None
+     data, slices, _ = collate(
+         data_list[0].__class__,
+         data_list=data_list,
+         increment=False,
+         add_batch=False,
+     )
+     return data, slices
+
+
+ def myseparate(cls, batch: BaseData, idx: int, slice_dict: Any) -> BaseData:
+     data = cls().stores_as(batch)
+     # Iterate over each storage object and separate all of its attributes:
+     for batch_store, data_store in zip(batch.stores, data.stores):
+         attrs = set(batch_store.keys())
+         for attr in attrs:
+             slices = slice_dict[attr]
+             data_store[attr] = _separate(attr, batch_store[attr], idx, slices,
+                                          batch, batch_store)
+     return data
+
+
+ def _separate(
+     key: str,
+     value: Any,
+     idx: int,
+     slices: Any,
+     batch: BaseData,
+     store: BaseStorage,
+ ):
+     # Narrow a `torch.Tensor` based on `slices`.
+     key = str(key)
+     cat_dim = batch.__cat_dim__(key, value, store)
+     start, end = int(slices[idx]), int(slices[idx + 1])
+     value = value.narrow(cat_dim or 0, start, end - start)
+     return value
+
+
+ def load_point(datasetname="south", k=5, small=[False, 50, 100]):
+     """
+     Load points and build graph pairs.
+     """
+     print("loading")
+     time1 = time.time()
+     if small[0]:
+         print("small south dataset k=5")
+         datasetname = "south"
+         k = 5
+         filename = os.path.join("data", datasetname, datasetname + f'_{k}.pt')
+         [data_graphs1, slices_graphs1, data_graphs2, slices_graphs2] = torch.load(filename)
+         flattened_list_graphs1 = [myseparate(cls=data_graphs1.__class__, batch=data_graphs1, idx=i, slice_dict=slices_graphs1) for i in range(small[1] * 2)]
+         flattened_list_graphs2 = [myseparate(cls=data_graphs2.__class__, batch=data_graphs2, idx=i, slice_dict=slices_graphs2) for i in range(small[2] * 2)]
+         unflattened_list_graphs1 = [flattened_list_graphs1[n:n + 2] for n in range(0, len(flattened_list_graphs1), 2)]
+         unflattened_list_graphs2 = [flattened_list_graphs2[n:n + 2] for n in range(0, len(flattened_list_graphs2), 2)]
+         print(f"Load data used {time.time() - time1:.1f} seconds")
+         return unflattened_list_graphs1, unflattened_list_graphs2
+     return process(datasetname, k)
+
+
+ def process(datasetname="south", k=5):
+     """
+     Build graph pairs.
+     """
+     time1 = time.time()
+     point_path = os.path.join("data", datasetname, datasetname + ".pkl")
+     with open(point_path, 'rb') as f:
+         data = pickle.load(f)
+     graphs1 = []
+     graphs2 = []
+     for day in data:
+         day_d1 = day[0]
+         day_d2 = day[1]
+         assert len(day_d1) < len(day_d2)
+         pos1 = day_d1[:, -2:]
+         edge_index1 = knn_graph(pos1, k=k)
+         pos2 = day_d2[:, -2:]
+         edge_index2 = knn_graph(pos2, k=k)
+         # Iteratively mask one point of day_d1 (the high-fidelity data) to build
+         # high-fidelity graphs; they all share the same graph structure.
+         for i in range(day_d1.shape[0]):
+             day_d1_copy = day_d1.clone().detach()
+             target = day_d1[i, 0]
+             day_d1_copy[i, 0] = 0
+             target_index = torch.tensor(i, dtype=torch.long)
+             is_source = torch.ones(day_d1.shape[0], dtype=torch.bool)
+             is_source[i] = False
+             graph1 = Data(x=day_d1_copy, pos=pos1, edge_index=edge_index1, target=target[None], target_index=target_index[None], is_source=is_source, datasource=torch.tensor(0, dtype=torch.long)[None])
+             # Build the paired low-fidelity graph, which adds the masked point
+             # from day_d1, so its structure changes with i.
+             day_plus2 = torch.cat([day_d1_copy[i][None, :], day_d2])
+             pos_plus2 = day_plus2[:, -2:]
+             edge_index_plus2 = knn_graph(pos_plus2, k=k)
+             is_source = torch.ones(day_d2.shape[0] + 1, dtype=torch.bool)
+             is_source[0] = False
+             graph2 = Data(x=day_plus2, pos=pos_plus2, edge_index=edge_index_plus2, target=target[None], target_index=torch.tensor(0, dtype=torch.long)[None], is_source=is_source, datasource=torch.tensor(0, dtype=torch.long)[None])
+             graphs1.append([graph1, graph2])
+         # Iteratively mask one point of day_d2 (the low-fidelity data) to build
+         # low-fidelity graphs; they all share the same graph structure.
+         for i in range(day_d2.shape[0]):
+             day_d2_copy = day_d2.clone().detach()
+             target = day_d2[i, 0]
+             day_d2_copy[i, 0] = 0
+             target_index = torch.tensor(i, dtype=torch.long)
+             is_source = torch.ones(day_d2.shape[0], dtype=torch.bool)
+             is_source[i] = False
+             graph2 = Data(x=day_d2_copy, pos=pos2, edge_index=edge_index2, target=target[None], target_index=target_index[None], is_source=is_source, datasource=torch.tensor(1, dtype=torch.long)[None])
+             # Build the paired high-fidelity graph, which adds the masked point
+             # from day_d2, so its structure changes with i.
+             day_plus1 = torch.cat([day_d2_copy[i][None, :], day_d1])
+             pos_plus1 = day_plus1[:, -2:]
+             edge_index_plus1 = knn_graph(pos_plus1, k=k)
+             is_source = torch.ones(day_d1.shape[0] + 1, dtype=torch.bool)
+             is_source[0] = False
+             graph1 = Data(x=day_plus1, pos=pos_plus1, edge_index=edge_index_plus1, target=target[None], target_index=torch.tensor(0, dtype=torch.long)[None], is_source=is_source, datasource=torch.tensor(1, dtype=torch.long)[None])
+             graphs2.append([graph1, graph2])
+     np.random.shuffle(graphs1)
+     np.random.shuffle(graphs2)
+     return [graphs1, graphs2]
+
+
+ class MergeNeighborDataset(torch.utils.data.Dataset):
+     """Customized dataset for each domain."""
+     def __init__(self, X):
+         self.X = X  # set data
+
+     def __len__(self):
+         return len(self.X)  # return length
+
+     def __getitem__(self, idx):
+         return self.X[idx]
+
+
+ def kneighbor_point(datasetname="south", k=1, daily=False):
+     """
+     Build k-neighbor pairings.
+     """
+     ranking_path = os.path.join("data", datasetname, datasetname + "_ranking.pkl")
+     with open(ranking_path, 'rb') as f:
+         rankings = pickle.load(f)
+     point_path = os.path.join("data", datasetname, datasetname + ".pkl")
+     with open(point_path, 'rb') as f:
+         days = pickle.load(f)
+     samples = []
+     for i in range(len(days)):
+         day_d1 = days[i][0]
+         day_d2 = days[i][1]
+         ranking = rankings[i]
+         # Iterate over the points of day_d1 (the high-fidelity data) to build
+         # samples, pairing each point with the mean of its k nearest
+         # low-fidelity neighbors.
+         sample1 = []
+         for j in range(day_d1.shape[0]):
+             point1 = day_d1[j]
+             point1_neighbors = day_d2[ranking[j, :k]]
+             point1_neighbor = torch.mean(point1_neighbors, dim=0)
+             sample1.append([point1, point1_neighbor])
+         if daily:
+             samples.append(sample1)
+         else:
+             samples.extend(sample1)
+     if not daily:
+         return [samples]
+     return samples
+
+
+ if __name__ == '__main__':
+     pass
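A minimal usage sketch for the loaders above, mirroring the call in train.py; running it assumes `data/south/south.pkl` has been downloaded from the Dropbox link:

```python
# Sketch: build graph pairs the same way train.py does.
# With small[0]=False, load_point falls through to process(); with
# small[0]=True it instead reads a cached data/south/south_5.pt and keeps
# the first small[1]/small[2] pairs per fidelity.
import dataset

graphs1, graphs2 = dataset.load_point("south", k=5, small=[False, 200, 500])
graph_hi, graph_lo = graphs1[0]          # one masked-station graph pair
print(graph_hi.num_nodes, graph_lo.num_nodes)
print(graph_hi.target, graph_hi.target_index, graph_hi.datasource)
```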
model.py ADDED
@@ -0,0 +1,386 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.parallel
+ import torch.utils.data
+ import numpy as np
+ import torch.nn.functional as F
+ from torch.nn import Parameter
+
+ from torch_geometric.nn.dense.linear import Linear
+ from torch_geometric.nn.conv import MessagePassing
+ from torch_geometric.utils import softmax
+ from torch_geometric.nn.inits import glorot, zeros
+
+ from torch_scatter import scatter
+ from utils.utils import triplets, get_angle, GaussianSmearing
+ from torch.nn import ModuleList
+ from math import pi as PI
+ import math
+
+ """
+ The theory-based grid-cell spatial relation encoder; see
+ "Learning Grid Cells as Vector Representation of Self-Position Coupled with
+ Matrix Representation of Self-Motion", https://openreview.net/forum?id=Syx0Mh05YQ
+ """
+
+
+ def _cal_freq_list(freq_init, frequency_num, max_radius, min_radius):
+     if freq_init == "random":
+         # The frequency used for each block (alpha in the paper);
+         # freq_list shape: (frequency_num,)
+         freq_list = np.random.random(size=[frequency_num]) * max_radius
+     elif freq_init == "geometric":
+         # Geometric progression of timescales between min_radius and max_radius.
+         log_timescale_increment = (math.log(float(max_radius) / float(min_radius)) /
+                                    (frequency_num * 1.0 - 1))
+
+         timescales = min_radius * np.exp(
+             np.arange(frequency_num).astype(float) * log_timescale_increment)
+
+         freq_list = 1.0 / timescales
+
+     return freq_list
+
+
+ class TheoryGridCellSpatialRelationEncoder(nn.Module):
+     """
+     Given a list of (deltaX, deltaY), encode them using the position encoding function.
+     """
+     def __init__(self, spa_embed_dim, coord_dim=2, frequency_num=16,
+                  max_radius=10000, min_radius=1000, freq_init="geometric", ffn=None):
+         """
+         Args:
+             spa_embed_dim: the output spatial relation embedding dimension
+             coord_dim: the dimension of space (2D, 3D, or other)
+             frequency_num: the number of sinusoids with different frequencies/wavelengths
+             max_radius: the largest context radius this model can handle
+         """
+         super(TheoryGridCellSpatialRelationEncoder, self).__init__()
+         self.frequency_num = frequency_num
+         self.coord_dim = coord_dim
+         self.max_radius = max_radius
+         self.min_radius = min_radius
+         self.spa_embed_dim = spa_embed_dim
+         self.freq_init = freq_init
+
+         # The frequency used for each block (alpha in the paper).
+         self.cal_freq_list()
+         self.cal_freq_mat()
+
+         # Three unit vectors, 120 degrees apart from each other.
+         self.unit_vec1 = np.asarray([1.0, 0.0])                        # 0 degrees
+         self.unit_vec2 = np.asarray([-1.0 / 2.0, math.sqrt(3) / 2.0])  # 120 degrees
+         self.unit_vec3 = np.asarray([-1.0 / 2.0, -math.sqrt(3) / 2.0]) # 240 degrees
+
+         self.input_embed_dim = self.cal_input_dim()
+         self.ffn = ffn
+
+     def cal_freq_list(self):
+         self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)
+
+     def cal_freq_mat(self):
+         # freq_mat shape: (frequency_num, 1)
+         freq_mat = np.expand_dims(self.freq_list, axis=1)
+         # self.freq_mat shape: (frequency_num, 6)
+         self.freq_mat = np.repeat(freq_mat, 6, axis=1)
+
+     def cal_input_dim(self):
+         # Compute the dimension of the encoded spatial relation embedding.
+         return int(6 * self.frequency_num)
+
+     def make_input_embeds(self, coords):
+         if type(coords) == np.ndarray:
+             assert self.coord_dim == np.shape(coords)[2]
+             coords = list(coords)
+         elif type(coords) == list:
+             assert self.coord_dim == len(coords[0][0])
+         elif type(coords) == torch.Tensor:
+             assert self.coord_dim == coords.shape[2]
+             coords = coords.detach().cpu().numpy()
+         else:
+             raise Exception("Unknown coords data type for GridCellSpatialRelationEncoder")
+
+         # (batch_size, num_context_pt, coord_dim)
+         coords_mat = np.asarray(coords).astype(float)
+         batch_size = coords_mat.shape[0]
+         num_context_pt = coords_mat.shape[1]
+
+         # Compute the dot product between [deltaX, deltaY] and each unit_vec.
+         # (batch_size, num_context_pt, 1)
+         angle_mat1 = np.expand_dims(np.matmul(coords_mat, self.unit_vec1), axis=-1)
+         # (batch_size, num_context_pt, 1)
+         angle_mat2 = np.expand_dims(np.matmul(coords_mat, self.unit_vec2), axis=-1)
+         # (batch_size, num_context_pt, 1)
+         angle_mat3 = np.expand_dims(np.matmul(coords_mat, self.unit_vec3), axis=-1)
+
+         # (batch_size, num_context_pt, 6)
+         angle_mat = np.concatenate([angle_mat1, angle_mat1, angle_mat2, angle_mat2, angle_mat3, angle_mat3], axis=-1)
+         # (batch_size, num_context_pt, 1, 6)
+         angle_mat = np.expand_dims(angle_mat, axis=-2)
+         # (batch_size, num_context_pt, frequency_num, 6)
+         angle_mat = np.repeat(angle_mat, self.frequency_num, axis=-2)
+         # (batch_size, num_context_pt, frequency_num, 6)
+         angle_mat = angle_mat * self.freq_mat
+         # (batch_size, num_context_pt, frequency_num*6)
+         spr_embeds = np.reshape(angle_mat, (batch_size, num_context_pt, -1))
+
+         # Apply the sinusoid functions: sin for dim 2i, cos for dim 2i+1.
+         # spr_embeds: (batch_size, num_context_pt, frequency_num*6 = input_embed_dim)
+         spr_embeds[:, :, 0::2] = np.sin(spr_embeds[:, :, 0::2])  # dim 2i
+         spr_embeds[:, :, 1::2] = np.cos(spr_embeds[:, :, 1::2])  # dim 2i+1
+
+         return spr_embeds
+
+     def forward(self, coords):
+         """
+         Given a list of coords (deltaX, deltaY), return their spatial relation embedding.
+         Args:
+             coords: a python list with shape (batch_size, num_context_pt, coord_dim)
+         Return:
+             sprenc: Tensor of shape (batch_size, num_context_pt, spa_embed_dim)
+         """
+         spr_embeds = self.make_input_embeds(coords)
+
+         # spr_embeds: (batch_size, num_context_pt, input_embed_dim)
+         spr_embeds = torch.FloatTensor(spr_embeds)
+         if self.ffn is not None:
+             return self.ffn(spr_embeds)
+         else:
+             return spr_embeds
+
+
+ theoryencoder = TheoryGridCellSpatialRelationEncoder(spa_embed_dim=8)
+
+
+ class GFusion(nn.Module):
+     def __init__(self, h_channel=16, input_featuresize=32, localdepth=2, num_interactions=3,
+                  finaldepth=3, num_of_datasources=2, share="111", batchnorm="False"):
+         # `share` is a string of three 0/1 flags: share the geometry MLP, the
+         # interaction blocks, and the final MLP across data sources
+         # (train.py passes "101" by default).
+         super(GFusion, self).__init__()
+         self.training = True
+         self.h_channel = h_channel
+         self.input_featuresize = input_featuresize
+         self.localdepth = localdepth
+         self.num_interactions = num_interactions
+         self.finaldepth = finaldepth
+         self.batchnorm = batchnorm
+         self.activation = nn.ReLU()
+
+         num_gaussians = (1, 12)
+         self.theta_expansion = GaussianSmearing(-PI, PI, num_gaussians[1])
+         self.mlps_list = ModuleList()
+         if int(share[0]) == 1:
+             mlp_geo = ModuleList()
+             for i in range(self.localdepth):
+                 if i == 0:
+                     mlp_geo.append(Linear(sum(num_gaussians), h_channel))
+                 else:
+                     mlp_geo.append(Linear(h_channel, h_channel))
+                 if self.batchnorm == "True":
+                     mlp_geo.append(nn.BatchNorm1d(h_channel))
+                 mlp_geo.append(self.activation)
+             for i in range(num_of_datasources):
+                 self.mlps_list.append(mlp_geo)
+         else:
+             for i in range(num_of_datasources):
+                 mlp_geo = ModuleList()
+                 for j in range(self.localdepth):
+                     if j == 0:
+                         mlp_geo.append(Linear(sum(num_gaussians), h_channel))
+                     else:
+                         mlp_geo.append(Linear(h_channel, h_channel))
+                     if self.batchnorm == "True":
+                         mlp_geo.append(nn.BatchNorm1d(h_channel))
+                     mlp_geo.append(self.activation)
+                 self.mlps_list.append(mlp_geo)
+         self.mlps_list_backup = ModuleList()
+         for i in range(num_of_datasources):
+             mlp_geo = ModuleList()
+             for j in range(self.localdepth):
+                 if j == 0:
+                     mlp_geo.append(Linear(4, h_channel))  # for the FN (no-edge-rep) version
+                 else:
+                     mlp_geo.append(Linear(h_channel, h_channel))
+                 if self.batchnorm == "True":
+                     mlp_geo.append(nn.BatchNorm1d(h_channel))
+                 mlp_geo.append(self.activation)
+             self.mlps_list_backup.append(mlp_geo)
+         self.translinear = Linear(input_featuresize + 1, self.h_channel)
+         self.interactions_list = ModuleList()
+         if int(share[1]) == 1:
+             interactions = ModuleList()
+             for i in range(self.num_interactions):
+                 block = SPNN(
+                     in_ch=self.input_featuresize,
+                     hidden_channels=self.h_channel,
+                     activation=self.activation,
+                     finaldepth=self.finaldepth,
+                     batchnorm=self.batchnorm,
+                     num_input_geofeature=self.h_channel,
+                 )
+                 interactions.append(block)
+             for i in range(num_of_datasources):
+                 self.interactions_list.append(interactions)
+         else:
+             for i in range(num_of_datasources):
+                 interactions = ModuleList()
+                 for j in range(self.num_interactions):
+                     block = SPNN(
+                         in_ch=self.input_featuresize,
+                         hidden_channels=self.h_channel,
+                         activation=self.activation,
+                         finaldepth=self.finaldepth,
+                         batchnorm=self.batchnorm,
+                         num_input_geofeature=self.h_channel,
+                     )
+                     interactions.append(block)
+                 self.interactions_list.append(interactions)
+         self.finalMLP_list = ModuleList()
+         if int(share[2]) == 1:
+             finalMLP = ModuleList()
+             for i in range(self.finaldepth + 1):
+                 finalMLP.append(Linear(self.h_channel, self.h_channel))
+                 if self.batchnorm == "True":
+                     finalMLP.append(nn.BatchNorm1d(self.h_channel))
+                 finalMLP.append(self.activation)
+             finalMLP.append(Linear(self.h_channel, 1))
+             for i in range(num_of_datasources):
+                 self.finalMLP_list.append(finalMLP)
+         else:
+             for i in range(num_of_datasources):
+                 finalMLP = ModuleList()
+                 for j in range(self.finaldepth + 1):
+                     finalMLP.append(Linear(self.h_channel, self.h_channel))
+                     if self.batchnorm == "True":
+                         finalMLP.append(nn.BatchNorm1d(self.h_channel))
+                     finalMLP.append(self.activation)
+                 finalMLP.append(Linear(self.h_channel, 1))
+                 self.finalMLP_list.append(finalMLP)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         for i in range(len(self.mlps_list)):
+             for lin in self.mlps_list[i]:
+                 if isinstance(lin, Linear):
+                     torch.nn.init.xavier_uniform_(lin.weight)
+                     lin.bias.data.fill_(0)
+         for i in range(len(self.interactions_list)):
+             for block in self.interactions_list[i]:
+                 block.reset_parameters()
+         for finalMLP in self.finalMLP_list:
+             for lin in finalMLP:
+                 if isinstance(lin, Linear):
+                     torch.nn.init.xavier_uniform_(lin.weight)
+                     lin.bias.data.fill_(0)
+
+     def single_forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch, input_feature, is_source, edge_rep, datasource_idx):
+         distances = {}
+         thetas = {}
+         if edge_rep:
+             # Geometry encoding: inverse edge length plus the Gaussian-smeared
+             # signed angle of each 2-hop wedge (i -> j -> k).
+             i, j, k = edge_index_2rd
+             distances[1] = (coords[edge_index[0]] - coords[edge_index[1]]).norm(p=2, dim=1)
+             theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
+             v1 = torch.cross(F.pad(coords[j] - coords[i], (0, 1)), F.pad(coords[k] - coords[j], (0, 1)), dim=1)[..., 2]
+             flag = torch.sign(v1)
+             flag[flag == 0] = -1
+             thetas[1] = scatter(theta_ijk * flag, edx_2nd, dim=0, dim_size=edge_index.shape[1], reduce='min')
+             thetas[1] = self.theta_expansion(thetas[1])
+             geo_encoding_1st = distances[1][:, None]
+             geo_encoding_1st[geo_encoding_1st == 0] = 1E-10
+             geo_encoding_1st = torch.pow(geo_encoding_1st, -1)
+             geo_encoding_2nd = thetas[1]
+             geo_encoding = torch.cat([geo_encoding_1st, geo_encoding_2nd], dim=-1)
+         else:
+             # coords = theoryencoder(coords[None, :])
+             # coords = coords[0].to("cuda")
+             coords_j = coords[edge_index[0]]
+             coords_i = coords[edge_index[1]]
+             geo_encoding = torch.cat([coords_j, coords_i], dim=-1)
+         if edge_rep:
+             for lin in self.mlps_list[datasource_idx]:
+                 geo_encoding = lin(geo_encoding)
+         else:
+             for lin in self.mlps_list_backup[datasource_idx]:
+                 geo_encoding = lin(geo_encoding)
+             # Ablation: zero out the geometry encoding entirely.
+             geo_encoding = torch.zeros_like(geo_encoding, device=geo_encoding.device, dtype=geo_encoding.dtype)
+         node_feature = self.translinear(input_feature[:, :-2])
+         for interaction in self.interactions_list[datasource_idx]:
+             node_feature = interaction(node_feature, geo_encoding, edge_index, is_source)
+         return node_feature
+
+     def forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch, input_feature, is_source, edge_rep):
+         outputs = []
+         for i in range(len(coords)):
+             output = self.single_forward(coords[i], edge_index[i], edge_index_2rd[i], edx_2nd[i], batch[i], input_feature[i], is_source[i], edge_rep, i)
+             for lin in self.finalMLP_list[i]:
+                 output = lin(output)
+             outputs.append(output)
+         return outputs
+
+
+ class SPNN(torch.nn.Module):
+     def __init__(
+         self,
+         in_ch,
+         hidden_channels,
+         activation=torch.nn.ReLU(),
+         finaldepth=3,
+         batchnorm="False",
+         num_input_geofeature=13,
+     ):
+         super(SPNN, self).__init__()
+         self.activation = activation
+         self.finaldepth = finaldepth
+         self.batchnorm = batchnorm
+         self.num_input_geofeature = num_input_geofeature
+         self.att = Parameter(torch.Tensor(1, hidden_channels), requires_grad=True)
+
+         self.WMLP = ModuleList()
+         for i in range(self.finaldepth + 1):
+             if i == 0:
+                 self.WMLP.append(Linear(hidden_channels * 2 + num_input_geofeature, hidden_channels))
+             else:
+                 self.WMLP.append(Linear(hidden_channels, hidden_channels))
+             if self.batchnorm == "True":
+                 self.WMLP.append(nn.BatchNorm1d(hidden_channels))
+             self.WMLP.append(self.activation)
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         for lin in self.WMLP:
+             if isinstance(lin, Linear):
+                 torch.nn.init.xavier_uniform_(lin.weight)
+                 lin.bias.data.fill_(0)
+         glorot(self.att)
+
+     def forward(self, node_feature, geo_encoding, edge_index, is_source):
+         j, i = edge_index
+         input_feature = node_feature.clone()
+         if node_feature is None:
+             concatenated_vector = geo_encoding
+         else:
+             node_attr_0st = node_feature[i]
+             node_attr_1st = node_feature[j]
+             concatenated_vector = torch.cat(
+                 [
+                     node_attr_0st,
+                     node_attr_1st,
+                     geo_encoding,
+                 ],
+                 dim=-1,
+             )
+         x_i = concatenated_vector
+         for lin in self.WMLP:
+             x_i = lin(x_i)
+         input_feature_j = input_feature[edge_index[0]]
+         x_i = F.leaky_relu(x_i)
+         # Attention over incoming edges, normalized per target node.
+         alpha = F.leaky_relu(x_i * self.att).sum(dim=-1)
+         alpha = softmax(alpha, edge_index[1])
+
+         message = input_feature_j * alpha.unsqueeze(-1)
+         out_feature = scatter(message, edge_index[1], dim=0, reduce='add')
+         out_feature = input_feature + out_feature
+
+         return out_feature
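A hedged instantiation sketch of `GFusion`, using the same hyperparameters train.py passes for the south/north datasets (`h_channel=16`, `input_featuresize=30`, `share="101"`); the station counts and random features below are illustrative assumptions:

```python
# Sketch: one forward pass through GFusion for a pair of fidelity graphs,
# mirroring the training loop in train.py.
import torch
from torch_geometric.nn import knn_graph
from model import GFusion
from utils.utils import triplets

model = GFusion(h_channel=16, input_featuresize=30,
                localdepth=1, num_interactions=1, finaldepth=3, share="101")

n1, n2 = 40, 120                                     # station counts (made up)
x1, x2 = torch.randn(n1, 33), torch.randn(n2, 33)    # 31 features + lat/lon
pos1, pos2 = x1[:, -2:], x2[:, -2:]
e1, e2 = knn_graph(pos1, k=3), knn_graph(pos2, k=3)
t1, _, _, edx1 = triplets(e1, n1)                    # 2-hop wedges per graph
t2, _, _, edx2 = triplets(e2, n2)
src1 = torch.ones(n1, dtype=torch.bool)
src2 = torch.ones(n2, dtype=torch.bool)

out1, out2 = model([pos1, pos2], [e1, e2], [t1, t2], [edx1, edx2],
                   [None, None], [x1, x2], [src1, src2], edge_rep=True)
print(out1.shape, out2.shape)                        # per-node scalar predictions
```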
requirements.txt ADDED
@@ -0,0 +1,300 @@
+ # This file may be used to create an environment using:
+ # $ conda create --name <env> --file <this file>
+ # platform: linux-64
+ _libgcc_mutex=0.1=conda_forge
+ _openmp_mutex=4.5=2_gnu
+ _py-xgboost-mutex=2.0=cpu_0
+ absl-py=1.3.0=py310h06a4308_0
+ aiohttp=3.8.1=py310h7f8727e_1
+ aiosignal=1.2.0=pyhd3eb1b0_0
+ anyio=3.6.2=pyhd8ed1ab_0
+ argon2-cffi=21.3.0=pyhd8ed1ab_0
+ argon2-cffi-bindings=21.2.0=py310h5764c6d_3
+ asttokens=2.0.5=pyhd3eb1b0_0
+ async-timeout=4.0.2=py310h06a4308_0
+ attrs=21.4.0=pyhd3eb1b0_0
+ autograd=1.5=pyhd8ed1ab_0
+ autopep8=1.6.0=pyhd3eb1b0_1
+ backcall=0.2.0=pyhd3eb1b0_0
+ beautifulsoup4=4.11.2=pyha770c72_0
+ blas=1.0=mkl
+ bleach=6.0.0=pyhd8ed1ab_0
+ blinker=1.4=py310h06a4308_0
+ blosc=1.21.1=h83bc5f7_3
+ boost-cpp=1.74.0=h75c5d50_8
+ bottleneck=1.3.5=py310ha9d4c09_0
+ branca=0.5.0=pyhd8ed1ab_0
+ brotli=1.0.9=h5eee18b_7
+ brotli-bin=1.0.9=h5eee18b_7
+ brotlipy=0.7.0=py310h7f8727e_1002
+ bzip2=1.0.8=h7b6447c_0
+ c-ares=1.18.1=h7f8727e_0
+ ca-certificates=2022.12.7=ha878542_0
+ cachetools=4.2.2=pyhd3eb1b0_0
+ cairo=1.16.0=h19f5f5c_2
+ certifi=2022.12.7=pyhd8ed1ab_0
+ cffi=1.15.1=py310h74dc2b5_0
+ cfitsio=4.1.0=hd9d235c_0
+ charset-normalizer=2.0.4=pyhd3eb1b0_0
+ click=8.0.4=py310h06a4308_0
+ click-plugins=1.1.1=pyhd3eb1b0_0
+ cligj=0.7.2=pyhd3eb1b0_0
+ cryptography=38.0.1=py310h9ce1e76_0
+ cudatoolkit=11.6.0=hecad31d_10
+ curl=7.84.0=h5eee18b_0
+ cycler=0.11.0=pyhd3eb1b0_0
+ dataclasses=0.8=pyh6d0b6a4_7
+ debugpy=1.5.1=py310h295c915_0
+ decorator=5.1.1=pyhd3eb1b0_0
+ defusedxml=0.7.1=pyhd8ed1ab_0
+ entrypoints=0.4=py310h06a4308_0
+ executing=0.8.3=pyhd3eb1b0_0
+ expat=2.4.9=h6a678d5_0
+ ffmpeg=4.2.2=h20bf706_0
+ fftw=3.3.9=h27cfd23_1
+ fiona=1.8.21=py310h60a68a4_2
+ flit-core=3.8.0=pyhd8ed1ab_0
+ folium=0.12.1.post1=pyhd8ed1ab_1
+ font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
+ font-ttf-inconsolata=2.001=hcb22688_0
+ font-ttf-source-code-pro=2.030=hd3eb1b0_0
+ font-ttf-ubuntu=0.83=h8b1ccd4_0
+ fontconfig=2.14.0=h8e229c2_0
+ fonts-anaconda=1=h8fa9717_0
+ fonts-conda-ecosystem=1=hd3eb1b0_0
+ fonttools=4.25.0=pyhd3eb1b0_0
+ freetype=2.11.0=h70c0345_0
+ freexl=1.0.6=h27cfd23_0
+ frozenlist=1.2.0=py310h7f8727e_1
+ future=0.18.3=pyhd8ed1ab_0
+ gdal=3.5.1=py310hce6f0df_1
+ geographiclib=1.52=pyhd8ed1ab_0
+ geopandas=0.11.1=pyhd8ed1ab_0
+ geopandas-base=0.11.1=pyha770c72_0
+ geopy=2.2.0=pyhd8ed1ab_0
+ geos=3.11.0=h27087fc_0
+ geotiff=1.7.1=h4fc65e6_3
+ gettext=0.21.0=hf68c758_0
+ giflib=5.2.1=h7b6447c_0
+ glib=2.72.1=h6239696_0
+ glib-tools=2.72.1=h6239696_0
+ gmp=6.2.1=h295c915_3
+ gnutls=3.6.15=he1e5248_0
+ google-auth=2.6.0=pyhd3eb1b0_0
+ google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
+ gpy=1.10.0=py310hde88566_3
+ grpcio=1.42.0=py310hce63b2e_0
+ hdf4=4.2.15=h9772cbc_4
+ hdf5=1.12.1=nompi_h2386368_104
+ icu=70.1=h27087fc_0
+ idna=3.4=py310h06a4308_0
+ importlib-metadata=4.11.3=py310h06a4308_0
+ importlib_resources=5.12.0=pyhd8ed1ab_0
+ intel-openmp=2021.4.0=h06a4308_3561
+ ipykernel=6.15.2=py310h06a4308_0
+ ipython=8.4.0=py310h06a4308_0
+ ipython_genutils=0.2.0=py_1
+ jedi=0.18.1=py310h06a4308_1
+ jinja2=3.1.2=py310h06a4308_0
+ joblib=1.1.1=py310h06a4308_0
+ jpeg=9e=h7f8727e_0
+ json-c=0.16=h5eee18b_0
+ jsonschema=4.17.3=pyhd8ed1ab_0
+ jupyter_client=7.3.4=py310h06a4308_0
+ jupyter_core=4.10.0=py310h06a4308_0
+ jupyter_server=1.23.4=py310h06a4308_0
+ jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
+ kealib=1.4.15=hfe1a663_0
+ keyutils=1.6.1=h166bdaf_0
+ kiwisolver=1.4.2=py310h295c915_0
+ krb5=1.19.3=h3790be6_0
+ lame=3.100=h7b6447c_0
+ lcms2=2.12=h3be6417_0
+ ld_impl_linux-64=2.38=h1181459_1
+ lerc=3.0=h295c915_0
+ libbrotlicommon=1.0.9=h5eee18b_7
+ libbrotlidec=1.0.9=h5eee18b_7
+ libbrotlienc=1.0.9=h5eee18b_7
+ libcurl=7.84.0=h91b91d3_0
+ libdap4=3.20.6=hd7c4107_2
+ libdeflate=1.8=h7f8727e_5
+ libedit=3.1.20210910=h7f8727e_0
+ libev=4.33=h7f8727e_1
+ libffi=3.4.2=h7f98852_5
+ libgcc-ng=12.1.0=h8d9b700_16
+ libgdal=3.5.1=h32640fd_1
+ libgfortran-ng=11.2.0=h00389a5_1
+ libgfortran5=11.2.0=h1234567_1
+ libglib=2.72.1=h2d90d5f_0
+ libgomp=12.1.0=h8d9b700_16
+ libiconv=1.16=h7f8727e_2
+ libidn2=2.3.2=h7f8727e_0
+ libkml=1.3.0=h238a007_1014
+ libnetcdf=4.8.1=nompi_h329d8a1_102
+ libnghttp2=1.46.0=hce63b2e_0
+ libnsl=2.0.0=h5eee18b_0
+ libopus=1.3.1=h7b6447c_0
+ libpng=1.6.37=hbc83047_0
+ libpq=14.5=hd77ab85_0
+ libprotobuf=3.20.1=h4ff587b_0
+ libpysal=4.1.1=py_0
+ librttopo=1.1.0=hf730bdb_11
+ libsodium=1.0.18=h7b6447c_0
+ libspatialindex=1.9.3=h2531618_0
+ libspatialite=5.0.1=h38b5f51_18
+ libsqlite=3.39.3=h753d276_0
+ libssh2=1.10.0=h8f2d780_0
+ libstdcxx-ng=12.1.0=ha89aaad_16
+ libtasn1=4.16.0=h27cfd23_0
+ libtiff=4.4.0=hecacb30_0
+ libunistring=0.9.10=h27cfd23_0
+ libuuid=2.32.1=h7f98852_1000
+ libvpx=1.7.0=h439df22_0
+ libwebp=1.2.4=h11a3e52_0
+ libwebp-base=1.2.4=h5eee18b_0
+ libxcb=1.15=h7f8727e_0
+ libxgboost=1.7.1=cpu_ha3b9936_0
+ libxml2=2.9.14=h22db469_4
+ libzip=1.8.0=h5cef20c_0
+ libzlib=1.2.12=h166bdaf_2
+ littleutils=0.2.2=py_0
+ llvm-openmp=8.0.1=hc9558a2_0
+ lz4-c=1.9.3=h295c915_1
+ mapclassify=2.4.3=pyhd3eb1b0_0
+ markdown=3.3.4=py310h06a4308_0
+ markupsafe=2.1.1=py310h7f8727e_0
+ matplotlib-base=3.5.2=py310hf590b9c_0
+ matplotlib-inline=0.1.6=py310h06a4308_0
+ mgwr=2.1.2=py_0
+ mistune=2.0.5=pyhd8ed1ab_0
+ mkl=2021.4.0=h06a4308_640
+ mkl-service=2.4.0=py310h7f8727e_0
+ mkl_fft=1.3.1=py310hd6ae3a3_0
+ mkl_random=1.2.2=py310h00e6091_0
+ multidict=6.0.2=py310h5eee18b_0
+ munch=2.5.0=pyhd3eb1b0_0
+ munkres=1.1.4=py_0
+ nbclassic=0.5.3=pyhb4ecaf3_3
+ nbclient=0.5.13=pyhd8ed1ab_0
+ nbconvert=7.2.9=pyhd8ed1ab_0
+ nbconvert-core=7.2.9=pyhd8ed1ab_0
+ nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
+ nbformat=5.7.3=pyhd8ed1ab_0
+ ncurses=6.3=h5eee18b_3
+ nest-asyncio=1.5.5=py310h06a4308_0
+ nettle=3.7.3=hbbd107a_1
+ networkx=2.8.4=py310h06a4308_0
+ notebook=6.5.3=pyha770c72_0
+ notebook-shim=0.2.2=pyhd8ed1ab_0
+ nspr=4.33=h295c915_0
+ nss=3.78=h2350873_0
+ numexpr=2.8.3=py310hcea2de6_0
+ numpy=1.23.1=py310h1794996_0
+ numpy-base=1.23.1=py310hcba007f_0
+ oauthlib=3.2.1=py310h06a4308_0
+ ogb=1.3.5=pyhd8ed1ab_0
+ openai=0.28.0=pypi_0
+ openh264=2.1.1=h4ff587b_0
+ openjpeg=2.4.0=h3ad879b_0
+ openssl=1.1.1t=h0b41bf4_0
+ outdated=0.2.2=pyhd8ed1ab_0
+ packaging=21.3=pyhd3eb1b0_0
+ pandas=1.4.3=py310h6a678d5_0
+ pandoc=2.19.2=ha770c72_0
+ pandocfilters=1.5.0=pyhd8ed1ab_0
+ paramz=0.9.5=py_0
+ parso=0.8.3=pyhd3eb1b0_0
+ pcre=8.45=h295c915_0
+ pexpect=4.8.0=pyhd3eb1b0_3
+ pickleshare=0.7.5=pyhd3eb1b0_1003
+ pillow=9.2.0=py310hace64e9_1
+ pip=22.1.2=py310h06a4308_0
+ pixman=0.40.0=h7f8727e_1
+ pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
+ poppler=22.04.0=h1434ded_1
+ poppler-data=0.4.11=h06a4308_0
+ postgresql=14.5=hfdbbde3_0
+ proj=9.0.1=h93bde94_1
+ prometheus_client=0.16.0=pyhd8ed1ab_0
+ prompt-toolkit=3.0.20=pyhd3eb1b0_0
+ protobuf=3.20.1=py310h295c915_0
+ psutil=5.9.0=py310h5eee18b_0
+ ptyprocess=0.7.0=pyhd3eb1b0_2
+ pure_eval=0.2.2=pyhd3eb1b0_0
+ py-xgboost=1.7.1=cpu_py310hd1aba9c_0
+ pyasn1=0.4.8=pyhd3eb1b0_0
+ pyasn1-modules=0.2.8=py_0
+ pycodestyle=2.8.0=pyhd3eb1b0_0
+ pycparser=2.21=pyhd3eb1b0_0
+ pygments=2.11.2=pyhd3eb1b0_0
+ pyjwt=2.4.0=py310h06a4308_0
+ pyopenssl=22.0.0=pyhd3eb1b0_0
+ pyparsing=3.0.9=py310h06a4308_0
+ pyproj=3.4.0=py310hf94497c_0
+ pyrsistent=0.19.3=py310h1fa729e_0
+ pysocks=1.7.1=py310h06a4308_0
+ python=3.10.6=h582c2e5_0_cpython
+ python-dateutil=2.8.2=pyhd3eb1b0_0
+ python-fastjsonschema=2.16.3=pyhd8ed1ab_0
+ python_abi=3.10=2_cp310
+ pytorch=1.12.1=py3.10_cuda11.6_cudnn8.3.2_0
+ pytorch-mutex=1.0=cuda
+ pytz=2022.1=py310h06a4308_0
+ pyzmq=23.2.0=py310h6a678d5_0
+ readline=8.1.2=h7f8727e_1
+ requests=2.28.1=py310h06a4308_0
+ requests-oauthlib=1.3.0=py_0
+ rsa=4.7.2=pyhd3eb1b0_1
+ rtree=0.9.7=py310h06a4308_1
+ scikit-learn=1.1.3=py310h6a678d5_0
+ scipy=1.9.3=py310hd5efca6_0
+ send2trash=1.8.0=pyhd8ed1ab_0
+ setuptools=63.4.1=py310h06a4308_0
+ shapely=1.8.4=py310h5e49deb_0
+ six=1.16.0=pyhd3eb1b0_1
+ snappy=1.1.9=h295c915_0
+ sniffio=1.3.0=pyhd8ed1ab_0
+ soupsieve=2.3.2.post1=pyhd8ed1ab_0
+ spglm=1.0.8=py_0
+ spreg=1.3.0=pyhd8ed1ab_0
+ sqlite=3.39.2=h5082296_0
+ stack_data=0.2.0=pyhd3eb1b0_0
+ tensorboard=2.10.1=pyhd8ed1ab_0
+ tensorboard-data-server=0.6.0=py310hca6d32c_0
+ tensorboard-plugin-wit=1.8.1=py310h06a4308_0
+ terminado=0.17.1=pyh41d4057_0
+ threadpoolctl=2.2.0=pyh0d69192_0
+ tiledb=2.9.5=h1e4a385_0
+ tinycss2=1.2.1=pyhd8ed1ab_0
+ tk=8.6.12=h1ccaba5_0
+ toml=0.10.2=pyhd3eb1b0_0
+ torch-cluster=1.6.0=pypi_0
+ torch-geometric=2.1.0.post1=pypi_0
+ torch-scatter=2.0.9=pypi_0
+ torch-sparse=0.6.15=pypi_0
+ torch-spline-conv=1.2.1=pypi_0
+ torch-tb-profiler=0.4.0=pypi_0
+ torchaudio=0.12.1=py310_cu116
+ torchvision=0.13.1=py310_cu116
+ tornado=6.2=py310h5eee18b_0
+ tqdm=4.64.0=py310h06a4308_0
+ traitlets=5.1.1=pyhd3eb1b0_0
+ typing_extensions=4.3.0=py310h06a4308_0
+ tzcode=2022c=h166bdaf_0
+ tzdata=2022a=hda174b7_0
+ urllib3=1.26.12=py310h06a4308_0
+ wcwidth=0.2.5=pyhd3eb1b0_0
+ webencodings=0.5.1=py_1
+ websocket-client=1.5.1=pyhd8ed1ab_0
+ werkzeug=2.0.3=pyhd3eb1b0_0
+ wheel=0.37.1=pyhd3eb1b0_0
+ x264=1!157.20191217=h7b6447c_0
+ xerces-c=3.2.3=h55805fa_5
+ xgboost=1.7.1=cpu_py310hd1aba9c_0
+ xyzservices=2022.9.0=py310h06a4308_0
+ xz=5.2.6=h166bdaf_0
+ yarl=1.8.1=py310h5eee18b_0
+ zeromq=4.3.4=h2531618_0
+ zipp=3.8.0=py310h06a4308_0
+ zlib=1.2.12=h7f8727e_2
+ zstd=1.5.2=ha4553b6_0
run.sh ADDED
@@ -0,0 +1,4 @@
+ # dataset=north
+ # dataset=south
+ dataset=flu
+ python3 ./train.py --dataset $dataset --manualSeed True --man_seed 5770
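Note that train.py parses its boolean flags as literal strings, so `--manualSeed True` must be spelled exactly. A couple of hedged alternative invocations built only from flags defined in `get_args()` (the values shown are illustrative; defaults include `--num_neighbors 3`, `--batchSize 512`, `--nepoch 201`):

```sh
# Train on the north dataset with a fixed seed (values illustrative).
python3 ./train.py --dataset north --manualSeed True --man_seed 5770
# High-fidelity-only ablation on south, without TensorBoard logging.
python3 ./train.py --dataset south --single_high True --log False
```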
train.py ADDED
@@ -0,0 +1,421 @@
+ import argparse
+ import os
+ import random
+ import torch
+ import pandas as pd
+ import numpy as np
+ import time
+ import torch.optim as optim
+
+ from matplotlib import cm
+ import matplotlib.pyplot as plt
+ import json
+ from model import GFusion
+ import torch.nn.functional as F
+ from torch_geometric.data import Data
+ from torch_geometric.loader import DataLoader
+
+ from torch_geometric.utils import add_self_loops
+ from torch.nn.functional import softmax
+ from torch_geometric.nn import knn_graph
+ import copy
+
+ torch.autograd.set_detect_anomaly(True)
+ from sklearn.metrics import explained_variance_score, mean_squared_error, mean_absolute_error, r2_score, precision_score, recall_score, f1_score, roc_auc_score, roc_curve, auc
+ from sklearn.feature_selection import r_regression
+ import pickle
+ from utils.utils import triplets, unique, pos2key
+ from torch.utils.tensorboard import SummaryWriter
+ from datetime import datetime
+ import dataset
+
+
+ def count_parameters(model):
+     return sum(p.numel() for p in model.parameters() if p.requires_grad)
+
+
+ blue = lambda x: '\033[94m' + x + '\033[0m'
+ red = lambda x: '\033[31m' + x + '\033[0m'
+ green = lambda x: '\033[32m' + x + '\033[0m'
+ yellow = lambda x: '\033[33m' + x + '\033[0m'
+ greenline = lambda x: '\033[42m' + x + '\033[0m'
+ yellowline = lambda x: '\033[43m' + x + '\033[0m'
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--log', type=str, default="True")
+     parser.add_argument('--loadmodel', type=str, default="False")
+     parser.add_argument('--split_dataset', type=str, default="False")
+     parser.add_argument('--model', type=str, default="GFusion")
+
+     # ablation
+     parser.add_argument('--edge_rep', type=str, default="True")
+     parser.add_argument('--single_high', type=str, default="False")
+     parser.add_argument('--fidelity_train', type=str, default="True")
+     parser.add_argument('--fidelity_low_weight', type=float, default=-1.0)
+     parser.add_argument('--share', type=str, default="101")
+
+     parser.add_argument('--dataset', type=str, default='flu')
+     parser.add_argument('--manualSeed', type=str, default="False")
+     parser.add_argument('--man_seed', type=int, default=12345)
+     parser.add_argument('--test_per_round', type=int, default=10)
+     parser.add_argument('--patience', type=int, default=30)  # scheduler
+     parser.add_argument('--nepoch', type=int, default=201)
+     parser.add_argument('--lr', type=float, default=1e-3)
+     parser.add_argument('--activation', type=str, default='relu')  # or 'lrelu'
+     parser.add_argument('--batchSize', type=int, default=512)
+
+     parser.add_argument('--num_neighbors', type=int, default=3)
+     parser.add_argument('--regression_loss', type=str, default='l2')
+
+     parser.add_argument('--h_ch', type=int, default=16)
+     parser.add_argument('--localdepth', type=int, default=1)  # depth of mlp(distance), mlp(theta); >= 1
+     parser.add_argument('--num_interactions', type=int, default=1)  # >= 1
+     parser.add_argument('--finaldepth', type=int, default=3)  # depth of mlp(concat node_attr and geo_encoding)
+
+     args = parser.parse_args()
+     # Boolean flags are parsed as the literal strings "True"/"False".
+     args.log = args.log == "True"
+     args.loadmodel = args.loadmodel == "True"
+     args.split_dataset = args.split_dataset == "True"
+     args.edge_rep = args.edge_rep == "True"
+     args.single_high = args.single_high == "True"
+     args.fidelity_train = (args.fidelity_train == "True" and args.single_high is False
+                            and args.fidelity_low_weight == -1.0)
+     args.manualSeed = args.manualSeed == "True"
+     args.save_dir = os.path.join('./save/', args.dataset)
+     return args
+
+
+ def main(args, train_Loader, val_Loader, test_Loader):
+     if flag:
+         return
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     measure_Pearsonr = r_regression
+     criterion_l1 = torch.nn.L1Loss()  # reduction='sum'
+     criterion_l2 = torch.nn.MSELoss()
+     criterion = criterion_l1 if args.regression_loss == 'l1' else criterion_l2
+     if args.model in ['GFusion']:
+         # Weighted variants that scale each sample's loss by its fidelity weight.
+         def myL1(pred, true, weight=None, reduction='mean'):
+             loss = abs(pred - true)
+             num = len(pred)
+             if weight is not None:
+                 loss = [weight[i] * loss[i] for i in range(num)]
+             loss = sum(loss)
+             if reduction == 'mean':
+                 loss = loss / num
+             return loss
+
+         def myL2(pred, true, weight=None, reduction='mean'):
+             loss = (pred - true) ** 2
+             num = len(pred)
+             if weight is not None:
+                 loss = [weight[i] * loss[i] for i in range(num)]
+             loss = sum(loss)
+             if reduction == 'mean':
+                 loss = loss / num
+             return loss
+
+         criterion = myL1 if args.regression_loss == 'l1' else myL2
+     num_of_fidelities = len(train_graphs[0])
+
+     def reweight_fidelity():
+         if args.single_high:
+             weighted_fidelity_weight[0] = 1
+             weighted_fidelity_weight[1] = 0
+         elif args.fidelity_low_weight != -1.0:
+             weighted_fidelity_weight[0] = 1
+             weighted_fidelity_weight[1] = args.fidelity_low_weight
+         else:
+             # Softmax over the learnable per-fidelity logits.
+             exped_f = [torch.exp(fidelity_weight[i]) for i in range(num_of_fidelities)]
+             fsum = sum(exped_f)
+             for i in range(num_of_fidelities):
+                 weighted_fidelity_weight[i] = exped_f[i] / fsum
+
+     fidelity_weight, weighted_fidelity_weight = [], []
+     if args.dataset in ['south', "north", "flu"]:
+         for i in range(num_of_fidelities):
+             fidelity_weight += [torch.tensor(1.0 / num_of_fidelities, dtype=torch.float32).requires_grad_()]
+             weighted_fidelity_weight += [0]
+     elif args.dataset in ["syn"]:
+         fidelity_weight = [torch.tensor(1, dtype=torch.float32).requires_grad_(),
+                            torch.tensor(0.0, dtype=torch.float32).requires_grad_()]
+         for i in range(num_of_fidelities):
+             # fidelity_weight += [torch.tensor(1.0/num_of_fidelities, dtype=torch.float32).requires_grad_()]
+             weighted_fidelity_weight += [0]
+     reweight_fidelity()
+     if args.dataset in ['south', "north"]:
+         x_in = 30
+     elif args.dataset in ['flu']:
+         x_in = 0
+     elif args.dataset == 'syn':
+         x_in = 1
+     else:
+         raise Exception('Dataset not recognized.')
+     if args.model == "GFusion":
+         GFusion_model = GFusion(h_channel=args.h_ch, input_featuresize=x_in,
+                                 localdepth=args.localdepth, num_interactions=args.num_interactions,
+                                 finaldepth=args.finaldepth, share=args.share)
+         GFusion_model.to(device)
+         optimizer = torch.optim.Adam(list(GFusion_model.parameters()), lr=args.lr)
+         if args.fidelity_train:
+             optimizer2 = torch.optim.Adam(fidelity_weight, lr=optimizer.param_groups[0]['lr'] * 10)
+             scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer2, factor=0.1, patience=args.patience, min_lr=1e-8)
+         scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=args.patience, min_lr=1e-8)
+
+     def train(GFusion_model):
+         epochloss = 0
+         y_hat, y_true, y_hat_logit = [], [], []
+         optimizer.zero_grad()
+         if args.fidelity_train: optimizer2.zero_grad()
+         if args.model == "GFusion":
+             GFusion_model.train()
+             for i, data in enumerate(train_Loader):
+                 if num_of_fidelities == 2:
+                     x1, pos1, edge_index1, batch1, target_index1, target1, is_source1 = data[0].x, data[0].pos, data[0].edge_index, data[0].batch, data[0].target_index, data[0].target, data[0].is_source
+                     x2, pos2, edge_index2, batch2, target_index2, target2, is_source2 = data[1].x, data[1].pos, data[1].edge_index, data[1].batch, data[1].target_index, data[1].target, data[1].is_source
+                 if args.dataset == 'syn':
+                     x1[:, 1] = x1[:, 1] + x1[:, 2]
+                     x1 = x1[:, [0, 1, 3, 4]]
+                     x2[:, 1] = x2[:, 1] + x2[:, 2]
+                     x2 = x2[:, [0, 1, 3, 4]]
+                 x1, pos1, target1, x2, pos2, target2 = x1.to(torch.float32), pos1.to(torch.float32), target1.to(torch.float32), x2.to(torch.float32), pos2.to(torch.float32), target2.to(torch.float32)
+                 x2[x2[:, 0] > 6666, 0] = 6666
+                 # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
+                 datasource = data[0].datasource
+                 Y = target1
+                 assert torch.equal(target1, target2)
+                 Y[Y > 6666] = 6666
+                 x1, pos1, edge_index1, batch1, target_index1, is_source1 = x1.to(device), pos1.to(device), edge_index1.to(device), batch1.to(device), target_index1.to(device), is_source1.to(device)
+                 x2, pos2, edge_index2, batch2, target_index2, is_source2 = x2.to(device), pos2.to(device), edge_index2.to(device), batch2.to(device), target_index2.to(device), is_source2.to(device)
+                 # Triplets are not the same across graphs during training,
+                 # so recompute them per batch.
+                 num_nodes1 = x1.shape[0]
+                 num_nodes2 = x2.shape[0]
+                 edge_index_2rd_1, _, _, edx_2nd_1 = triplets(edge_index1, num_nodes1)
+                 edge_index_2rd_2, _, _, edx_2nd_2 = triplets(edge_index2, num_nodes2)
+
+                 pm25_1, pm25_2 = GFusion_model([pos1, pos2], [edge_index1, edge_index2], [edge_index_2rd_1, edge_index_2rd_2],
+                                                [edx_2nd_1, edx_2nd_2], [batch1, batch2], [x1, x2], [is_source1, is_source2], args.edge_rep)
+                 pm25_1, pm25_2 = pm25_1[target_index1], pm25_2[target_index2]
+
+                 if args.dataset == 'syn':
+                     pred = (pm25_1 * weighted_fidelity_weight[0] + pm25_2 * weighted_fidelity_weight[1]).cpu()
+                 else:
+                     pred = F.relu((pm25_1 * weighted_fidelity_weight[0] + pm25_2 * weighted_fidelity_weight[1]).cpu())
+
+                 loss_weight = [weighted_fidelity_weight[i] for i in datasource]
+                 loss1 = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1), loss_weight)
+                 # Record predictions.
+                 y_hat += list(pred.detach().numpy().reshape(-1))
+                 y_true += list(Y.detach().numpy().reshape(-1))
+                 loss = loss1
+                 loss.backward()
+                 epochloss += loss
+                 optimizer.step()
+                 optimizer.zero_grad()
+                 if args.fidelity_train:
+                     optimizer2.step()
+                     optimizer2.zero_grad()
+                     reweight_fidelity()
+         return epochloss.item() / len(train_Loader), y_hat, y_true
+
+     def test(loader, GFusion_model, fidelity_weight):
+         if not args.single_high:
+             weighted_fidelity_weight = [i.detach() for i in fidelity_weight]
+             exped_f = [torch.exp(fidelity_weight[i]) for i in range(num_of_fidelities)]
+             fsum = sum(exped_f)
+             for i in range(num_of_fidelities):
+                 weighted_fidelity_weight[i] = exped_f[i] / fsum
+         else:
+             weighted_fidelity_weight = [1, 0]
+         y_hat, y_true, y_hat_logit = [], [], []
+         loss_total, pred_num = 0, 0
+         GFusion_model.eval()
+         for i, data in enumerate(loader):
+             if num_of_fidelities == 2:
+                 x1, pos1, edge_index1, batch1, target_index1, target1, is_source1 = data[0].x, data[0].pos, data[0].edge_index, data[0].batch, data[0].target_index, data[0].target, data[0].is_source
+                 x2, pos2, edge_index2, batch2, target_index2, target2, is_source2 = data[1].x, data[1].pos, data[1].edge_index, data[1].batch, data[1].target_index, data[1].target, data[1].is_source
+             if args.dataset == 'syn':
+                 x1[:, 1] = x1[:, 1] + x1[:, 2]
+                 x1 = x1[:, [0, 1, 3, 4]]
+                 x2[:, 1] = x2[:, 1] + x2[:, 2]
+                 x2 = x2[:, [0, 1, 3, 4]]
+             x1, pos1, target1, x2, pos2, target2 = x1.to(torch.float32), pos1.to(torch.float32), target1.to(torch.float32), x2.to(torch.float32), pos2.to(torch.float32), target2.to(torch.float32)
+             x2[x2[:, 0] > 6666, 0] = 6666
+             # edge_index, _ = add_self_loops(edge_index, num_nodes=x.size(0))
+             datasource = data[0].datasource
+             Y = target1
+             assert torch.equal(target1, target2)
+             Y[Y > 6666] = 6666
+             x1, pos1, edge_index1, batch1, target_index1, is_source1 = x1.to(device), pos1.to(device), edge_index1.to(device), batch1.to(device), target_index1.to(device), is_source1.to(device)
+             x2, pos2, edge_index2, batch2, target_index2, is_source2 = x2.to(device), pos2.to(device), edge_index2.to(device), batch2.to(device), target_index2.to(device), is_source2.to(device)
+
+             num_nodes1 = x1.shape[0]
+             num_nodes2 = x2.shape[0]
+             edge_index_2rd_1, num_2nd_neighbors_1, edx_1st_1, edx_2nd_1 = triplets(edge_index1, num_nodes1)
+             edge_index_2rd_2, num_2nd_neighbors_2, edx_1st_2, edx_2nd_2 = triplets(edge_index2, num_nodes2)
+             pm25_1, pm25_2 = GFusion_model([pos1, pos2], [edge_index1, edge_index2], [edge_index_2rd_1, edge_index_2rd_2],
+                                            [edx_2nd_1, edx_2nd_2], [batch1, batch2], [x1, x2], [is_source1, is_source2], args.edge_rep)
+             pm25_1, pm25_2 = pm25_1[target_index1], pm25_2[target_index2]
+             with torch.no_grad():
+                 if args.dataset == 'syn':
+                     pred = (pm25_1 * weighted_fidelity_weight[0] + pm25_2 * weighted_fidelity_weight[1]).cpu()
+                 else:
+                     pred = F.relu((pm25_1 * weighted_fidelity_weight[0] + pm25_2 * weighted_fidelity_weight[1]).cpu())
+                 assert all(datasource == 0)
+                 loss1 = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1)) * weighted_fidelity_weight[0]
+                 # Record predictions.
+                 y_hat += list(pred.detach().numpy().reshape(-1))
+                 y_true += list(Y.detach().numpy().reshape(-1))
+                 pred_num += len(Y.reshape(-1, 1))
+                 loss = loss1
+                 loss_total += loss.detach() * len(Y.reshape(-1, 1))
+         return loss_total / pred_num, y_hat, y_true
+
+     if args.loadmodel:
+         try:
+             suffix = 'Oct31-11:50:30'
+             GFusion_model.load_state_dict(torch.load(os.path.join("save", args.dataset, 'model', 'best_GFusion_model_' + suffix + '.pth')), strict=True)
+             best_GFusion_model = copy.deepcopy(GFusion_model)
+         except OSError:
+             pass
+     else:
+         best_val_trigger = 1e3
+         old_lr = 1e3
+         suffix = "{}{}-{}:{}:{}".format(datetime.now().strftime("%h"),
+                                         datetime.now().strftime("%d"),
+                                         datetime.now().strftime("%H"),
+                                         datetime.now().strftime("%M"),
+                                         datetime.now().strftime("%S"))
+         if args.log:
+             writer = SummaryWriter(os.path.join(tensorboard_dir, suffix))
+         for epoch in range(args.nepoch):
+             if args.model in ['GFusion']: train_loss, y_hat, y_true = train(GFusion_model)
+             if args.log:
+                 writer.add_scalar('loss/Train', train_loss, epoch)
+             if args.dataset in ['south', "north", 'syn', 'flu']:
+                 train_mae = mean_absolute_error(y_true, y_hat)
+                 train_rmse = np.sqrt(mean_squared_error(y_true, y_hat))
+                 if args.log:
+                     writer.add_scalar('mae/Train', train_mae, epoch)
+                     writer.add_scalar('rmse/Train', train_rmse, epoch)
+                 print(f"epoch[{epoch:d}] train_loss : {train_loss:.3f} train_mae : {train_mae:.3f} train_rmse : {train_rmse:.3f}")
+             if args.model in ['GFusion']:
+                 if args.fidelity_train == True:
+                     print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}")
+                     print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}")
+             if epoch % args.test_per_round == 0:
+                 if args.model in ['GFusion']:
+                     val_loss, yhat_val, ytrue_val = test(val_Loader, GFusion_model, fidelity_weight)
+                     test_loss, yhat_test, ytrue_test = test(test_Loader, GFusion_model, fidelity_weight)
+                 if args.log:
+                     writer.add_scalar('loss/val', val_loss, epoch)
+                     writer.add_scalar('loss/test', test_loss, epoch)
+                 if args.dataset in ['south', "north", 'syn', 'flu']:
+                     val_mae = mean_absolute_error(ytrue_val, yhat_val)
+                     val_rmse = np.sqrt(mean_squared_error(ytrue_val, yhat_val))
+                     if args.log:
+                         writer.add_scalar('mae/val', val_mae, epoch)
+                         writer.add_scalar('rmse/val', val_rmse, epoch)
+                     print(blue(f"epoch[{epoch:d}] val_mae : {val_mae:.3f} val_rmse : {val_rmse:.3f}"))
+                     test_mae = mean_absolute_error(ytrue_test, yhat_test)
+                     test_rmse = np.sqrt(mean_squared_error(ytrue_test, yhat_test))
+                     test_var = explained_variance_score(ytrue_test, yhat_test)
+                     test_coefOfDetermination = r2_score(ytrue_test, yhat_test)
+                     test_Pearsonr = measure_Pearsonr(np.array(yhat_test).reshape(-1, 1), np.array(ytrue_test).reshape(-1))[0]
+                     if args.log:
+                         writer.add_scalar('mae/test', test_mae, epoch)
+                         writer.add_scalar('rmse/test', test_rmse, epoch)
+                     print(blue(f"epoch[{epoch:d}] test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_Pearsonr: {test_Pearsonr:.3f} test_coefOfDetermination: {test_coefOfDetermination:.3f}"))
+                 if args.model in ['GFusion']:
+                     if args.fidelity_train == True:
+                         print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}")
+                         print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}")
+                 val_trigger = val_mae
+                 if val_trigger < best_val_trigger:
+                     best_val_trigger = val_trigger
+                     best_GFusion_model = copy.deepcopy(GFusion_model)
+                     best_fidelity = copy.deepcopy(fidelity_weight)
+                     best_info = [epoch, val_trigger]
+             # Update the learning rate once epoch >= 30.
+             if epoch >= 30:
+                 lr = scheduler.optimizer.param_groups[0]['lr']
+                 if old_lr != lr:
+                     print(red('lr'), epoch, (lr), sep=', ')
+                     old_lr = lr
+                 scheduler.step(val_trigger)
+                 if args.fidelity_train:
+                     scheduler2.step(val_trigger)
+     val_loss, yhat_val, ytrue_val = test(val_Loader, best_GFusion_model, best_fidelity)
+     test_loss, yhat_test, ytrue_test = test(test_Loader, best_GFusion_model, best_fidelity)
+     if args.dataset in ['south', "north", 'syn', 'flu']:
+         val_mae = mean_absolute_error(ytrue_val, yhat_val)
+         val_rmse = np.sqrt(mean_squared_error(ytrue_val, yhat_val))
+         val_var = explained_variance_score(ytrue_val, yhat_val)
+         print(blue(f"best_val val_mae: {val_mae:.3f} val_rmse: {val_rmse:.3f} val_var: {val_var:.3f}"))
+
+         test_mae = mean_absolute_error(ytrue_test, yhat_test)
+         test_rmse = np.sqrt(mean_squared_error(ytrue_test, yhat_test))
+         test_var = explained_variance_score(ytrue_test, yhat_test)
+         test_coefOfDetermination = r2_score(ytrue_test, yhat_test)
+         test_Pearsonr = measure_Pearsonr(np.array(yhat_test).reshape(-1, 1), np.array(ytrue_test).reshape(-1))[0]
+         print(blue(f"best_test test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_var: {test_var:.3f}"))
+     if not args.loadmodel:
+         # Save training info and the best result.
+         result_file = os.path.join(info_dir, suffix)
+         with open(result_file, 'w') as f:
+             print(args.num_neighbors, args.nepoch, sep=' ', file=f)
+             print(f"fidelity weight: {best_fidelity[0]:.3f}, {best_fidelity[1]:.3f}", file=f)
+             print("Random Seed: ", Seed, file=f)
+             if args.dataset in ['south', "north", 'syn', 'flu']:
+                 print(f"MAE val : {val_mae:.3f}, Test : {test_mae:.3f}", file=f)
+                 print(f"rmse val : {val_rmse:.3f}, Test : {test_rmse:.3f}", file=f)
+                 print(f"var val : {val_var:.3f}, Test : {test_var:.3f}", file=f)
+                 print(f"test_coefOfDetermination: {test_coefOfDetermination:.3f}, test_Pearsonr : {test_Pearsonr:.3f}", file=f)
+             print(f"Best info: {best_info}", file=f)
+             for i in [[a, getattr(args, a)] for a in args.__dict__]:
+                 print(i, sep='\n', file=f)
+         with open(os.path.join(model_dir, 'best_f_weight' + "_" + suffix + ".pkl"), 'wb') as handle:
+             pickle.dump(fidelity_weight, handle)
+         torch.save(best_GFusion_model.state_dict(), os.path.join(model_dir, 'best_GFusion_model' + "_" + suffix + '.pth'))
+     print("done")
+
+
+ if __name__ == '__main__':
+     args = get_args()
+     if not os.path.exists(args.save_dir):
+         os.makedirs(args.save_dir, exist_ok=True)
+     tensorboard_dir = os.path.join(args.save_dir, 'log')
+     if not os.path.exists(tensorboard_dir):
+         os.makedirs(tensorboard_dir, exist_ok=True)
+     model_dir = os.path.join(args.save_dir, 'model')
+     if not os.path.exists(model_dir):
+         os.makedirs(model_dir, exist_ok=True)
+     info_dir = os.path.join(args.save_dir, 'info')
+     if not os.path.exists(info_dir):
+         os.makedirs(info_dir, exist_ok=True)
+     Seed = args.man_seed if args.manualSeed else random.randint(1, 10000)
+     print("Random Seed: ", Seed)
+     random.seed(Seed)
+     torch.manual_seed(Seed)
+     np.random.seed(Seed)
+     flag = 0
+     if args.dataset in ['south', "north", 'syn', "flu"]:
+         graphs1, graphs2 = dataset.load_point(args.dataset, args.num_neighbors, [False, 200, 500])
+         np.random.shuffle(graphs1)
+         val_test_split = int(np.around(2 / 10 * len(graphs1)))
+         train_val_split = int(len(graphs1) - 2 * val_test_split)
+         if args.single_high:
+             train_graphs = graphs1[:train_val_split]
+         else:
+             train_graphs = graphs1[:train_val_split] + graphs2
+         val_graphs = graphs1[train_val_split:train_val_split + val_test_split]
+         test_graphs = graphs1[train_val_split + val_test_split:]
+
+         np.random.shuffle(train_graphs)
+         train_Loader = DataLoader(train_graphs, batch_size=args.batchSize)
+         val_Loader = DataLoader(val_graphs, batch_size=args.batchSize)
+         test_Loader = DataLoader(test_graphs, batch_size=args.batchSize)
+         print(f"train_pair_num: {len(train_graphs)}, val_pair_num: {len(val_graphs)}, test_pair_num: {len(test_graphs)}")
+     else:
+         raise Exception('Dataset not recognized.')
+     main(args, train_Loader, val_Loader, test_Loader)
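The learned fusion weights in `main()` are simply a softmax over two scalar logits (see `reweight_fidelity`); a minimal standalone illustration with toy values:

```python
# Sketch of the fidelity weighting used in train.py (toy numbers).
import torch

fidelity_weight = [torch.tensor(0.5).requires_grad_(),
                   torch.tensor(0.5).requires_grad_()]
exped = [torch.exp(w) for w in fidelity_weight]
weights = [e / sum(exped) for e in exped]        # softmax; sums to 1
pred = weights[0] * 3.0 + weights[1] * 5.0       # fused prediction
print([f"{w:.3f}" for w in weights], f"{pred:.3f}")
```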
utils/utils.py ADDED
@@ -0,0 +1,67 @@
+ import numpy as np
+
+ import networkx as nx
+ from networkx.utils import UnionFind
+
+ from typing import Optional
+ import torch
+ from torch import Tensor
+
+ from torch_sparse import SparseTensor
+ from scipy.sparse import csr_matrix
+ from math import pi as PI
+ import torch.nn.functional as F
+
+
+ def unique(sequence):
+     seen = set()
+     return [x for x in sequence if not (x in seen or seen.add(x))]
+
+
+ def pos2key(pos):
+     pos = pos.reshape(-1)
+     key = "{:08.4f}".format(pos[0]) + '_' + "{:08.4f}".format(pos[1])
+     return key
+
+
+ def get_angle(v1: Tensor, v2: Tensor):
+     # Pad 2-D vectors to 3-D so torch.cross is well defined.
+     if v1.shape[1] == 2:
+         v1 = F.pad(v1, (0, 1))
+     if v2.shape[1] == 2:
+         v2 = F.pad(v2, (0, 1))
+     return torch.atan2(
+         torch.cross(v1, v2, dim=1).norm(p=2, dim=1), (v1 * v2).sum(dim=1))
+
+
+ class GaussianSmearing(torch.nn.Module):
+     def __init__(self, start=-PI, stop=PI, num_gaussians=12):
+         super(GaussianSmearing, self).__init__()
+         offset = torch.linspace(start, stop, num_gaussians)
+         self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
+         self.register_buffer("offset", offset)
+
+     def forward(self, dist):
+         dist = dist.view(-1, 1) - self.offset.view(1, -1)
+         return torch.exp(self.coeff * torch.pow(dist, 2))
+
+
+ def triplets(edge_index, num_nodes):
+     row, col = edge_index
+
+     value = torch.arange(row.size(0), device=row.device)
+     adj_t = SparseTensor(row=row, col=col, value=value,
+                          sparse_sizes=(num_nodes, num_nodes))
+     adj_t_row = adj_t[col]
+     num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)
+
+     idx_i = row.repeat_interleave(num_triplets)
+     idx_j = col.repeat_interleave(num_triplets)
+     edx_1st = value.repeat_interleave(num_triplets)
+     idx_k = adj_t_row.storage.col()
+     edx_2nd = adj_t_row.storage.value()
+     # Drop degenerate triplets where two of (i, j, k) coincide.
+     mask1 = (idx_i == idx_k) & (idx_j != idx_i)
+     mask2 = (idx_i == idx_j) & (idx_j != idx_k)
+     mask3 = (idx_j == idx_k) & (idx_i != idx_k)
+     mask = ~(mask1 | mask2 | mask3)
+     idx_i, idx_j, idx_k, edx_1st, edx_2nd = idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask]
+
+     num_triplets_real = torch.cumsum(num_triplets, dim=0) - torch.cumsum(~mask, dim=0)[torch.cumsum(num_triplets, dim=0) - 1]
+
+     return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd
+
+
+ if __name__ == '__main__':
+     pass
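A toy check of the geometry helpers above: build a 4-node kNN graph, enumerate its 2-hop triplets, and smear the wedge angles into 12 Gaussian bins (the coordinates and k=2 are illustrative):

```python
import torch
from torch_geometric.nn import knn_graph
from utils.utils import triplets, get_angle, GaussianSmearing

pos = torch.tensor([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
edge_index = knn_graph(pos, k=2)                   # directed kNN edges
tri, num_tri, edx_1st, edx_2nd = triplets(edge_index, num_nodes=4)
i, j, k = tri                                      # paths i -> j -> k
theta = get_angle(pos[j] - pos[i], pos[k] - pos[j])
print(GaussianSmearing()(theta).shape)             # (num_triplets, 12)
```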