Dzy6
commited on
Commit
•
c7995e9
0
Parent(s):
init
Browse files- README.md +3 -0
- data/north/column +7 -0
- data/south/column +7 -0
- dataset.py +183 -0
- model.py +386 -0
- requirements.txt +300 -0
- run.sh +4 -0
- train.py +421 -0
- utils/utils.py +67 -0
README.md
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
# KDD24 Self-consistent Deep Geometric Learning for Heterogeneous Multi-source Spatial Point Data Prediction
|
2 |
+
|
3 |
+
The data is available on [Dropbox](https://www.dropbox.com/sh/fi5bsxqeuz46h6l/AABSkN6cav7omgvgATX1cs6ga?dl=0).
|
data/north/column
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
['DM8HA' 'Date.Local' 'relative_humidity_max' 'air_temperature_min'
|
2 |
+
'precipitation' 'air_temperature_max' 'wind_direction' 'solar_radiation'
|
3 |
+
'relative_humidity_min' 'wind_speed' 'elevation' 'CO' 'NH3' 'NOX' 'SO2'
|
4 |
+
'VOC' 'PM25-PRI' 'PM10-PRI' 'population_density_county' 'open_water'
|
5 |
+
'developed' 'bareRock_sand_clay' 'd_forest' 'e_forest' 'm_forest' 'shrub'
|
6 |
+
'grassland' 'pasture' 'crops' 'w_wetlands' 'eh_wetlands' 'cmaq'
|
7 |
+
'Latitude' 'Longitude']
|
data/south/column
ADDED
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
['DM8HA' 'Date.Local' 'relative_humidity_max' 'air_temperature_min'
|
2 |
+
'precipitation' 'air_temperature_max' 'wind_direction' 'solar_radiation'
|
3 |
+
'relative_humidity_min' 'wind_speed' 'elevation' 'CO' 'NH3' 'NOX' 'SO2'
|
4 |
+
'VOC' 'PM25-PRI' 'PM10-PRI' 'population_density_county' 'open_water'
|
5 |
+
'developed' 'bareRock_sand_clay' 'd_forest' 'e_forest' 'm_forest' 'shrub'
|
6 |
+
'grassland' 'pasture' 'crops' 'w_wetlands' 'eh_wetlands' 'cmaq'
|
7 |
+
'Latitude' 'Longitude']
|
dataset.py
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import print_function
|
2 |
+
import torch.utils.data as data
|
3 |
+
import os
|
4 |
+
import os.path
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
import pandas as pd
|
8 |
+
import sys
|
9 |
+
from torch_geometric.nn import knn_graph
|
10 |
+
from torch_geometric.data import Data
|
11 |
+
from torch_geometric.loader import DataLoader
|
12 |
+
from torch_geometric.utils import add_self_loops
|
13 |
+
from torch_geometric.data.collate import collate
|
14 |
+
from torch_geometric.data.separate import separate
|
15 |
+
import pickle
|
16 |
+
import time
|
17 |
+
|
18 |
+
from torch_geometric.data.data import BaseData
|
19 |
+
from torch_geometric.data.storage import BaseStorage
|
20 |
+
from typing import Any
|
21 |
+
def mycollate(data_list):
    r"""Merge a list of graph objects into one storage object plus slices.

    A list holding exactly one element is not collated: that element is
    returned unchanged together with ``None`` in place of the slices.
    Otherwise the work is delegated to ``collate`` with ``increment`` and
    ``add_batch`` both switched off.

    Returns:
        A ``(data, slices)`` tuple.
    """
    if len(data_list) != 1:
        element_cls = data_list[0].__class__
        merged, slice_dict, _unused = collate(
            element_cls,
            data_list=data_list,
            increment=False,
            add_batch=False,
        )
        return merged, slice_dict
    only_item = data_list[0]
    return only_item, None
|
34 |
+
def myseparate(cls, batch: BaseData, idx: int, slice_dict: Any) -> BaseData:
    """Extract the ``idx``-th element from a collated ``batch``.

    A fresh ``cls()`` object is created with the same store layout as
    ``batch``; every attribute of every store is then cut out with
    ``_separate`` using the slice boundaries found in ``slice_dict``.

    Args:
        cls: class instantiated (without arguments) to hold the result.
        batch: collated object to read from.
        idx: position of the element to extract.
        slice_dict: mapping from attribute name to its slice boundaries.

    Returns:
        The newly created object holding the extracted attributes.
    """
    extracted = cls().stores_as(batch)
    # Walk the stores of the source and the destination in lockstep.
    for source_store, target_store in zip(batch.stores, extracted.stores):
        for name in set(source_store.keys()):
            boundaries = slice_dict[name]
            target_store[name] = _separate(
                name, source_store[name], idx, boundaries, batch, source_store
            )
    return extracted
|
44 |
+
def _separate(
    key: str,
    value: Any,
    idx: int,
    slices: Any,
    batch: BaseData,
    store: BaseStorage,
):
    """Narrow ``value`` to the piece that belongs to element ``idx``.

    The dimension to cut along is asked from ``batch.__cat_dim__`` (a falsy
    answer such as ``None`` is treated as dimension 0); the piece spans
    ``slices[idx]`` up to ``slices[idx + 1]``.
    """
    name = str(key)
    axis = batch.__cat_dim__(name, value, store)
    lower = int(slices[idx])
    upper = int(slices[idx + 1])
    length = upper - lower
    return value.narrow(axis or 0, lower, length)
|
58 |
+
|
59 |
+
def load_point(datasetname="south", k=5, small=(False, 50, 100)):
    """
    Load point data and build graph pairs.

    Args:
        datasetname: name of the sub-directory of ``data`` (and file stem)
            to read. Ignored when ``small[0]`` is true.
        k: number of neighbours forwarded to ``process``. Ignored when
            ``small[0]`` is true.
        small: indexable ``(use_small, n_pairs1, n_pairs2)``. When
            ``small[0]`` is true, the pre-collated file
            ``data/south/south_5.pt`` is read instead of building graphs, and
            only the first ``small[1]`` / ``small[2]`` pairs of the two
            collections are returned.

    Returns:
        Two lists of ``[graph, graph]`` pairs. In small mode they are returned
        as a tuple; otherwise the result of ``process`` is returned as is.
    """
    print("loading")
    time1 = time.time()
    if small[0]:
        print("small south dataset k=5")
        # Small mode always uses the "south" file built with k=5, so the
        # caller's datasetname / k arguments are overridden here.
        datasetname = "south"
        k = 5
        filename = os.path.join("data", datasetname, datasetname + f'_{k}.pt')
        [data_graphs1, slices_graphs1, data_graphs2, slices_graphs2] = torch.load(filename)
        # The file stores the pairs flattened, so 2 * n graphs are separated
        # and then regrouped two by two.
        flattened_list_graphs1 = [
            myseparate(cls=data_graphs1.__class__, batch=data_graphs1, idx=i, slice_dict=slices_graphs1)
            for i in range(small[1] * 2)
        ]
        flattened_list_graphs2 = [
            myseparate(cls=data_graphs2.__class__, batch=data_graphs2, idx=i, slice_dict=slices_graphs2)
            for i in range(small[2] * 2)
        ]
        unflattened_list_graphs1 = [
            flattened_list_graphs1[n:n + 2] for n in range(0, len(flattened_list_graphs1), 2)
        ]
        unflattened_list_graphs2 = [
            flattened_list_graphs2[n:n + 2] for n in range(0, len(flattened_list_graphs2), 2)
        ]
        print(f"Load data used {time.time()-time1:.1f} seconds")
        return unflattened_list_graphs1, unflattened_list_graphs2
    return process(datasetname, k)
|
78 |
+
def process(datasetname="south", k=5):
    """
    Build graph pairs from ``data/<datasetname>/<datasetname>.pkl``.

    The pickle is iterated day by day. Each ``day`` provides ``day[0]`` (the
    high fidelity points) and ``day[1]`` (the low fidelity points), and the
    first must be strictly shorter than the second. Column 0 of a point is
    used as the prediction target and the last two columns as its position.

    For every point of a day one pair ``[graph1, graph2]`` is produced; the
    target value of that point is replaced by 0 in the graph features:

    * masking a high fidelity point: ``graph1`` is the high fidelity graph
      (fixed k-NN structure for the day) and ``graph2`` is the low fidelity
      point set with the masked point prepended at index 0 (k-NN rebuilt).
      Both carry ``datasource == 0``. The pair goes to ``graphs1``.
    * masking a low fidelity point: ``graph2`` is the low fidelity graph and
      ``graph1`` is the high fidelity point set with the masked point
      prepended at index 0. Both carry ``datasource == 1``. The pair goes to
      ``graphs2``.

    Args:
        datasetname: sub-directory of ``data`` and stem of the pickle file.
        k: number of neighbours passed to ``knn_graph``.

    Returns:
        ``[graphs1, graphs2]``; both lists are shuffled in place with
        ``np.random.shuffle`` before being returned.
    """
    point_path = os.path.join("data", datasetname, datasetname + ".pkl")
    # NOTE(review): pickle.load can execute arbitrary code contained in the
    # file; only use it on data files from a trusted source.
    with open(point_path, 'rb') as f:
        data = pickle.load(f)
    graphs1 = []
    graphs2 = []
    for day in data:
        day_d1 = day[0]
        day_d2 = day[1]
        assert(len(day_d1) < len(day_d2))
        pos1 = day_d1[:, -2:]
        edge_index1 = knn_graph(pos1, k=k)
        pos2 = day_d2[:, -2:]
        edge_index2 = knn_graph(pos2, k=k)

        # Iteratively mask each point in day_d1 (the high fidelity data) to
        # build high fidelity graphs, which all share the same structure.
        for i in range(day_d1.shape[0]):
            day_d1_copy = day_d1.clone().detach()
            target = day_d1[i, 0]
            day_d1_copy[i, 0] = 0
            target_index = torch.tensor(i, dtype=torch.long)
            is_source = torch.ones(day_d1.shape[0], dtype=torch.bool)
            is_source[i] = False
            graph1 = Data(
                x=day_d1_copy,
                pos=pos1,
                edge_index=edge_index1,
                target=target[None],
                target_index=target_index[None],
                is_source=is_source,
                datasource=torch.tensor(0, dtype=torch.long)[None],
            )

            # Build the paired low fidelity graph: the masked day_d1 point is
            # added as node 0, so the k-NN structure changes per sample.
            day_plus2 = torch.cat([day_d1_copy[i][None, :], day_d2])
            pos_plus2 = day_plus2[:, -2:]
            edge_index_plus2 = knn_graph(pos_plus2, k=k)
            is_source = torch.ones(day_d2.shape[0] + 1, dtype=torch.bool)
            is_source[0] = False
            graph2 = Data(
                x=day_plus2,
                pos=pos_plus2,
                edge_index=edge_index_plus2,
                target=target[None],
                target_index=torch.tensor(0, dtype=torch.long)[None],
                is_source=is_source,
                datasource=torch.tensor(0, dtype=torch.long)[None],
            )
            graphs1.append([graph1, graph2])

        # Iteratively mask each point in day_d2 (the low fidelity data) to
        # build low fidelity graphs, which all share the same structure.
        for i in range(day_d2.shape[0]):
            day_d2_copy = day_d2.clone().detach()
            target = day_d2[i, 0]
            day_d2_copy[i, 0] = 0
            target_index = torch.tensor(i, dtype=torch.long)
            is_source = torch.ones(day_d2.shape[0], dtype=torch.bool)
            is_source[i] = False
            graph2 = Data(
                x=day_d2_copy,
                pos=pos2,
                edge_index=edge_index2,
                target=target[None],
                target_index=target_index[None],
                is_source=is_source,
                datasource=torch.tensor(1, dtype=torch.long)[None],
            )

            # Build the paired high fidelity graph: the masked day_d2 point
            # is added as node 0, so the k-NN structure changes per sample.
            day_plus1 = torch.cat([day_d2_copy[i][None, :], day_d1])
            pos_plus1 = day_plus1[:, -2:]
            edge_index_plus1 = knn_graph(pos_plus1, k=k)
            is_source = torch.ones(day_d1.shape[0] + 1, dtype=torch.bool)
            is_source[0] = False
            graph1 = Data(
                x=day_plus1,
                pos=pos_plus1,
                edge_index=edge_index_plus1,
                target=target[None],
                target_index=torch.tensor(0, dtype=torch.long)[None],
                is_source=is_source,
                datasource=torch.tensor(1, dtype=torch.long)[None],
            )
            graphs2.append([graph1, graph2])
    np.random.shuffle(graphs1)
    np.random.shuffle(graphs2)
    return [graphs1, graphs2]
|
141 |
+
|
142 |
+
class MergeNeighborDataset(torch.utils.data.Dataset):
    """Map-style dataset that wraps an indexable collection of samples."""

    def __init__(self, X):
        # The collection is kept by reference; nothing is copied.
        self.X = X

    def __len__(self):
        """Return how many samples are stored."""
        samples = self.X
        return len(samples)

    def __getitem__(self, idx):
        """Return the sample stored at ``idx``."""
        samples = self.X
        return samples[idx]
|
150 |
+
def kneighbor_point(datasetname="south", k=1, daily=False):
    """
    Pair every high fidelity point with the mean of its k ranked neighbours.

    Reads ``data/<datasetname>/<datasetname>_ranking.pkl`` and
    ``data/<datasetname>/<datasetname>.pkl``. For each day and each point
    ``j`` of ``day[0]``, the first ``k`` entries of that day's ranking row
    ``j`` index into ``day[1]``; the selected rows are averaged.

    Args:
        datasetname: sub-directory of ``data`` and stem of both pickle files.
        k: number of ranked neighbours to average.
        daily: when true, samples stay grouped per day.

    Returns:
        ``daily=True``: a list with one list of ``[point, neighbour_mean]``
        pairs per day. ``daily=False``: a one-element list wrapping the flat
        list of all pairs.
    """
    with open(os.path.join("data", datasetname, datasetname + "_ranking.pkl"), 'rb') as handle:
        rankings = pickle.load(handle)
    with open(os.path.join("data", datasetname, datasetname + ".pkl"), 'rb') as handle:
        days = pickle.load(handle)

    samples = []
    # Grouped per day -> append the day's list; flat -> splice its items in.
    collect = samples.append if daily else samples.extend
    for day_idx in range(len(days)):
        high_fidelity = days[day_idx][0]
        low_fidelity = days[day_idx][1]
        day_ranking = rankings[day_idx]
        day_pairs = [
            [high_fidelity[row], torch.mean(low_fidelity[day_ranking[row, :k]], axis=0)]
            for row in range(high_fidelity.shape[0])
        ]
        collect(day_pairs)

    if daily:
        return samples
    return [samples]
|
181 |
+
|
182 |
+
if __name__ == '__main__':
    # Placeholder: the bare expression below does nothing when this module
    # is executed as a script.
    1
|
model.py
ADDED
@@ -0,0 +1,386 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import torch.nn.parallel
|
4 |
+
import torch.utils.data
|
5 |
+
import numpy as np
|
6 |
+
import torch.nn.functional as F
|
7 |
+
from torch.nn import Parameter
|
8 |
+
|
9 |
+
from torch_geometric.nn.dense.linear import Linear
|
10 |
+
from torch_geometric.nn.conv import MessagePassing
|
11 |
+
from torch_geometric.utils import softmax
|
12 |
+
# from dataset import
|
13 |
+
from torch_geometric.nn.inits import glorot, zeros
|
14 |
+
|
15 |
+
from torch_scatter import scatter
|
16 |
+
from utils.utils import triplets,get_angle,GaussianSmearing
|
17 |
+
from torch.nn import ModuleList
|
18 |
+
from math import pi as PI
|
19 |
+
import math
|
20 |
+
|
21 |
+
"""
|
22 |
+
The theory based Grid cell spatial relation encoder,
|
23 |
+
See https://openreview.net/forum?id=Syx0Mh05YQ
|
24 |
+
Learning Grid Cells as Vector Representation of Self-Position Coupled with Matrix Representation of Self-Motion
|
25 |
+
"""
|
26 |
+
def _cal_freq_list(freq_init, frequency_num, max_radius, min_radius):
|
27 |
+
if freq_init == "random":
|
28 |
+
# the frequence we use for each block, alpha in paper
|
29 |
+
# freq_list shape: (frequency_num)
|
30 |
+
freq_list = np.random.random(size=[frequency_num]) * max_radius
|
31 |
+
elif freq_init == "geometric":
|
32 |
+
# freq_list = []
|
33 |
+
# for cur_freq in range(frequency_num):
|
34 |
+
# base = 1.0/(np.power(max_radius, cur_freq*1.0/(frequency_num-1)))
|
35 |
+
# freq_list.append(base)
|
36 |
+
|
37 |
+
# freq_list = np.asarray(freq_list)
|
38 |
+
|
39 |
+
log_timescale_increment = (math.log(float(max_radius) / float(min_radius)) /
|
40 |
+
(frequency_num*1.0 - 1))
|
41 |
+
|
42 |
+
timescales = min_radius * np.exp(
|
43 |
+
np.arange(frequency_num).astype(float) * log_timescale_increment)
|
44 |
+
|
45 |
+
freq_list = 1.0/timescales
|
46 |
+
|
47 |
+
return freq_list
|
48 |
+
class TheoryGridCellSpatialRelationEncoder(nn.Module):
    """
    Given a list of (deltaX,deltaY), encode them using the position encoding function.

    Each coordinate is projected onto three unit vectors that are 120 degrees
    apart; every projection is scaled by each frequency and passed through
    sin (even output positions) and cos (odd output positions).
    """

    def __init__(self, spa_embed_dim, coord_dim=2, frequency_num=16,
                 max_radius=10000, min_radius=1000, freq_init="geometric", ffn=None):
        """
        Args:
            spa_embed_dim: the output spatial relation embedding dimension
                (only stored on the instance by this class)
            coord_dim: the dimension of space, 2D, 3D, or other
            frequency_num: the number of different sinusoidal with different frequencies/wavelengths
            max_radius: the largest context radius this model can handle
            min_radius: the smallest wavelength handed to ``_cal_freq_list``
            freq_init: frequency initialisation name handed to ``_cal_freq_list``
            ffn: optional callable applied to the encoding in ``forward``
        """
        super(TheoryGridCellSpatialRelationEncoder, self).__init__()
        self.frequency_num = frequency_num
        self.coord_dim = coord_dim
        self.max_radius = max_radius
        self.min_radius = min_radius
        self.spa_embed_dim = spa_embed_dim
        self.freq_init = freq_init

        # the frequency we use for each block, alpha in paper
        self.cal_freq_list()
        self.cal_freq_mat()

        # three unit vectors which are 120 degrees apart from each other
        self.unit_vec1 = np.asarray([1.0, 0.0])                        # 0
        self.unit_vec2 = np.asarray([-1.0/2.0, math.sqrt(3)/2.0])      # 120 degree
        self.unit_vec3 = np.asarray([-1.0/2.0, -math.sqrt(3)/2.0])     # 240 degree

        self.input_embed_dim = self.cal_input_dim()
        self.ffn = ffn

    def cal_freq_list(self):
        """Store the frequency list on ``self.freq_list``."""
        self.freq_list = _cal_freq_list(self.freq_init, self.frequency_num, self.max_radius, self.min_radius)

    def cal_freq_mat(self):
        """Store ``self.freq_list`` tiled to shape (frequency_num, 6) on ``self.freq_mat``."""
        # freq_mat shape: (frequency_num, 1)
        freq_mat = np.expand_dims(self.freq_list, axis=1)
        # self.freq_mat shape: (frequency_num, 6)
        self.freq_mat = np.repeat(freq_mat, 6, axis=1)

    def cal_input_dim(self):
        """Return the dimension of the encoded spatial relation embedding."""
        return int(6 * self.frequency_num)

    def make_input_embeds(self, coords):
        """Encode ``coords`` into a numpy array.

        Args:
            coords: ``numpy.ndarray``, ``list`` or ``torch.Tensor`` of shape
                (batch_size, num_context_pt, coord_dim). Subclasses of these
                types (e.g. ``torch.nn.Parameter``) are accepted as well.

        Returns:
            ``numpy`` array of shape (batch_size, num_context_pt, 6 * frequency_num).

        Raises:
            TypeError: if ``coords`` is none of the supported types.
        """
        # isinstance (instead of an exact type comparison) so that subclasses
        # such as torch.nn.Parameter are not rejected.
        if isinstance(coords, np.ndarray):
            assert self.coord_dim == np.shape(coords)[2]
        elif isinstance(coords, list):
            assert self.coord_dim == len(coords[0][0])
        elif isinstance(coords, torch.Tensor):
            assert self.coord_dim == (coords.shape)[2]
            coords = coords.detach().cpu().numpy()
        else:
            raise TypeError("Unknown coords data type for GridCellSpatialRelationEncoder")

        # (batch_size, num_context_pt, coord_dim)
        # astype returns a copy, so the in-place sin/cos below never touches
        # the caller's data.
        coords_mat = np.asarray(coords).astype(float)
        batch_size = coords_mat.shape[0]
        num_context_pt = coords_mat.shape[1]

        # compute the dot product between [deltaX, deltaY] and each unit_vec
        # (batch_size, num_context_pt, 1)
        angle_mat1 = np.expand_dims(np.matmul(coords_mat, self.unit_vec1), axis=-1)
        # (batch_size, num_context_pt, 1)
        angle_mat2 = np.expand_dims(np.matmul(coords_mat, self.unit_vec2), axis=-1)
        # (batch_size, num_context_pt, 1)
        angle_mat3 = np.expand_dims(np.matmul(coords_mat, self.unit_vec3), axis=-1)

        # (batch_size, num_context_pt, 6)
        angle_mat = np.concatenate(
            [angle_mat1, angle_mat1, angle_mat2, angle_mat2, angle_mat3, angle_mat3], axis=-1)
        # (batch_size, num_context_pt, 1, 6)
        angle_mat = np.expand_dims(angle_mat, axis=-2)
        # (batch_size, num_context_pt, frequency_num, 6)
        angle_mat = np.repeat(angle_mat, self.frequency_num, axis=-2)
        # (batch_size, num_context_pt, frequency_num, 6)
        angle_mat = angle_mat * self.freq_mat
        # (batch_size, num_context_pt, frequency_num*6)
        spr_embeds = np.reshape(angle_mat, (batch_size, num_context_pt, -1))

        # make sinusoid function
        # sin for 2i, cos for 2i+1
        # spr_embeds: (batch_size, num_context_pt, frequency_num*6=input_embed_dim)
        spr_embeds[:, :, 0::2] = np.sin(spr_embeds[:, :, 0::2])  # dim 2i
        spr_embeds[:, :, 1::2] = np.cos(spr_embeds[:, :, 1::2])  # dim 2i+1

        return spr_embeds

    def forward(self, coords):
        """
        Given a list of coords (deltaX, deltaY), give their spatial relation embedding
        Args:
            coords: a python list with shape (batch_size, num_context_pt, coord_dim)
        Return:
            sprenc: the encoding as a float tensor of shape
                (batch_size, num_context_pt, input_embed_dim), passed through
                ``self.ffn`` first when one was given
        """
        spr_embeds = self.make_input_embeds(coords)

        # spr_embeds: (batch_size, num_context_pt, input_embed_dim)
        spr_embeds = torch.FloatTensor(spr_embeds)
        if self.ffn is not None:
            return self.ffn(spr_embeds)
        else:
            return spr_embeds
|
158 |
+
# Module-level encoder instance, constructed once when this module is
# imported (spa_embed_dim=8, every other argument left at its default).
theoryencoder=TheoryGridCellSpatialRelationEncoder(spa_embed_dim=8)
|
159 |
+
|
160 |
+
class GFusion(nn.Module):
    """Message-passing network with one branch per data source.

    Every data source owns (or, when shared, aliases) three module stacks:
    an edge-geometry MLP, a stack of ``SPNN`` interaction blocks and a final
    MLP ending in a single output channel.
    """

    def __init__(self, h_channel=16, input_featuresize=32, localdepth=2, num_interactions=3,
                 finaldepth=3, num_of_datasources=2, share=True, batchnorm="False"):
        """
        Args:
            h_channel: hidden width used by all linear layers.
            input_featuresize: forwarded to ``SPNN`` as ``in_ch``; the node
                input projection expects ``input_featuresize + 1`` features.
            localdepth: number of linear layers in each edge-geometry MLP.
            num_interactions: number of ``SPNN`` blocks per branch.
            finaldepth: the final MLP has ``finaldepth + 1`` hidden layers.
            num_of_datasources: number of branches.
            share: three flags (any indexable such as ``"110"``; a flag is
                on when ``int(flag) == 1``) choosing whether the geometry
                MLP, the interaction stack and the final MLP are shared by
                all data sources. A plain ``bool`` applies to all three.
            batchnorm: the string ``"True"`` inserts ``BatchNorm1d`` after
                every linear layer; any other value disables it.
        """
        super(GFusion, self).__init__()
        self.training = True
        self.h_channel = h_channel
        self.input_featuresize = input_featuresize
        self.localdepth = localdepth
        self.num_interactions = num_interactions
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.activation = nn.ReLU()

        share_geo, share_interactions, share_final = self._parse_share(share)

        # 1 inverse-distance feature + 12 expanded angle features per edge.
        num_gaussians = (1, 12)
        self.theta_expansion = GaussianSmearing(-PI, PI, num_gaussians[1])

        self.mlps_list = self._per_source(
            lambda: self._build_geo_mlp(sum(num_gaussians)), share_geo, num_of_datasources)
        # Fallback geometry MLPs for edge_rep=False (input: the two endpoint
        # coordinates concatenated, i.e. 4 values); never shared.
        self.mlps_list_backup = self._per_source(
            lambda: self._build_geo_mlp(4), False, num_of_datasources)
        self.translinear = Linear(input_featuresize + 1, self.h_channel)
        self.interactions_list = self._per_source(
            self._build_interactions, share_interactions, num_of_datasources)
        self.finalMLP_list = self._per_source(
            self._build_final_mlp, share_final, num_of_datasources)
        self.reset_parameters()

    @staticmethod
    def _parse_share(share):
        """Turn ``share`` into three booleans (geometry MLP, interactions, final MLP)."""
        if isinstance(share, bool):
            # The signature default is a bool, which cannot be indexed.
            return share, share, share
        return tuple(int(share[position]) == 1 for position in range(3))

    @staticmethod
    def _per_source(build, shared, num_of_datasources):
        """Return a ModuleList with one entry per data source; when shared, all entries are the same module."""
        stack = ModuleList()
        if shared:
            module = build()
            for _ in range(num_of_datasources):
                stack.append(module)
        else:
            for _ in range(num_of_datasources):
                stack.append(build())
        return stack

    def _build_geo_mlp(self, in_dim):
        """Build one edge-geometry MLP mapping ``in_dim`` features to ``h_channel``."""
        mlp_geo = ModuleList()
        for depth in range(self.localdepth):
            if depth == 0:
                mlp_geo.append(Linear(in_dim, self.h_channel))
            else:
                mlp_geo.append(Linear(self.h_channel, self.h_channel))
            if self.batchnorm == "True":
                mlp_geo.append(nn.BatchNorm1d(self.h_channel))
            mlp_geo.append(self.activation)
        return mlp_geo

    def _build_interactions(self):
        """Build one stack of ``num_interactions`` SPNN blocks."""
        interactions = ModuleList()
        for _ in range(self.num_interactions):
            block = SPNN(
                in_ch=self.input_featuresize,
                hidden_channels=self.h_channel,
                activation=self.activation,
                finaldepth=self.finaldepth,
                batchnorm=self.batchnorm,
                num_input_geofeature=self.h_channel
            )
            interactions.append(block)
        return interactions

    def _build_final_mlp(self):
        """Build one output MLP: ``finaldepth + 1`` hidden layers, then a 1-channel linear layer."""
        finalMLP = ModuleList()
        for _ in range(self.finaldepth + 1):
            finalMLP.append(Linear(self.h_channel, self.h_channel))
            if self.batchnorm == "True":
                finalMLP.append(nn.BatchNorm1d(self.h_channel))
            finalMLP.append(self.activation)
        finalMLP.append(Linear(self.h_channel, 1))
        return finalMLP

    def reset_parameters(self):
        """Re-initialise the geometry MLPs, the interaction blocks and the final MLPs.

        Linear weights get Xavier-uniform values and biases are zeroed.
        ``mlps_list_backup`` and ``translinear`` are not touched here.
        """
        for i in range(len(self.mlps_list)):
            for lin in self.mlps_list[i]:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)
        for i in range(len(self.interactions_list)):
            for block in self.interactions_list[i]:
                block.reset_parameters()
        for finalMLP in self.finalMLP_list:
            for lin in finalMLP:
                if isinstance(lin, Linear):
                    torch.nn.init.xavier_uniform_(lin.weight)
                    lin.bias.data.fill_(0)

    def single_forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch, input_feature,
                       is_source, edge_rep, datasource_idx):
        """Run the branch of one data source and return its node features.

        Args:
            coords: node coordinates (indexed by node id).
            edge_index: pair of node-index rows describing the edges.
            edge_index_2rd: triple ``(i, j, k)`` of node indices describing
                second-order paths; only read when ``edge_rep`` is true.
            edx_2nd: for every triple, the edge it is aggregated onto.
            batch: unused here.
            input_feature: node features; the last two columns are dropped
                before the input projection.
            is_source: forwarded to every interaction block.
            edge_rep: whether to use the distance/angle edge encoding.
            datasource_idx: which branch to run.
        """
        distances = {}
        thetas = {}
        if edge_rep:
            i, j, k = edge_index_2rd
            distances[1] = (coords[edge_index[0]] - coords[edge_index[1]]).norm(p=2, dim=1)
            theta_ijk = get_angle(coords[j] - coords[i], coords[k] - coords[j])
            # z component of the cross product of the two (zero-padded) edge
            # vectors; its sign is used to sign the angle.
            v1 = torch.cross(F.pad(coords[j] - coords[i], (0, 1)),
                             F.pad(coords[k] - coords[j], (0, 1)), dim=1)[..., 2]
            flag = torch.sign((v1))
            flag[flag == 0] = -1
            thetas[1] = scatter(theta_ijk * flag, edx_2nd, dim=0,
                                dim_size=edge_index.shape[1], reduce='min')
            thetas[1] = self.theta_expansion(thetas[1])
            geo_encoding_1st = distances[1][:, None]
            # Avoid a division by zero for coincident points.
            geo_encoding_1st[geo_encoding_1st == 0] = 1E-10
            geo_encoding_1st = torch.pow(geo_encoding_1st, -1)
            geo_encoding_2nd = thetas[1]
            geo_encoding = torch.cat([geo_encoding_1st, geo_encoding_2nd], dim=-1)
        else:
            coords_j = coords[edge_index[0]]
            coords_i = coords[edge_index[1]]
            geo_encoding = torch.cat([coords_j, coords_i], dim=-1)
        if edge_rep:
            for lin in self.mlps_list[datasource_idx]:
                geo_encoding = lin(geo_encoding)
        else:
            for lin in self.mlps_list_backup[datasource_idx]:
                geo_encoding = lin(geo_encoding)
            # NOTE(review): the original indentation of the next statement
            # could not be recovered from the source view. It is placed in
            # this branch (edge_rep disabled -> all-zero edge encoding);
            # confirm it was not meant to run unconditionally.
            geo_encoding = torch.zeros_like(geo_encoding, device=geo_encoding.device,
                                            dtype=geo_encoding.dtype)
        node_feature = self.translinear(input_feature[:, :-2])
        for interaction in self.interactions_list[datasource_idx]:
            node_feature = interaction(node_feature, geo_encoding, edge_index, is_source)
        return node_feature

    def forward(self, coords, edge_index, edge_index_2rd, edx_2nd, batch, input_feature,
                is_source, edge_rep):
        """Run every data source through its branch and its final MLP.

        All arguments except ``edge_rep`` are sequences indexed by data
        source. Returns the list of per-source outputs.
        """
        outputs = []
        for i in range(len(coords)):
            output = self.single_forward(coords[i], edge_index[i], edge_index_2rd[i], edx_2nd[i],
                                         batch[i], input_feature[i], is_source[i], edge_rep, i)
            for lin in self.finalMLP_list[i]:
                output = lin(output)
            outputs.append(output)
        return outputs
|
321 |
+
|
322 |
+
class SPNN(torch.nn.Module):
    """Attention-weighted message-passing block with a residual connection."""

    def __init__(
        self,
        in_ch,
        hidden_channels,
        activation=torch.nn.ReLU(),
        finaldepth=3,
        batchnorm="False",
        num_input_geofeature=13
    ):
        """
        Args:
            in_ch: accepted for interface compatibility; not used here.
            hidden_channels: width of the node features and of the edge MLP.
            activation: module appended after every linear layer.
            finaldepth: the edge MLP has ``finaldepth + 1`` linear layers.
            batchnorm: the string ``"True"`` inserts ``BatchNorm1d`` after
                every linear layer; any other value disables it.
            num_input_geofeature: width of the per-edge geometric encoding.
        """
        super(SPNN, self).__init__()
        self.activation = activation
        self.finaldepth = finaldepth
        self.batchnorm = batchnorm
        self.num_input_geofeature = num_input_geofeature
        # Allocated uninitialised; filled by reset_parameters() below.
        self.att = Parameter(torch.Tensor(1, hidden_channels), requires_grad=True)

        self.WMLP = ModuleList()
        for i in range(self.finaldepth + 1):
            if i == 0:
                # Input: target node features, source node features and the
                # edge's geometric encoding, concatenated.
                self.WMLP.append(Linear(hidden_channels * 2 + num_input_geofeature, hidden_channels))
            else:
                self.WMLP.append(Linear(hidden_channels, hidden_channels))
            if self.batchnorm == "True":
                self.WMLP.append(nn.BatchNorm1d(hidden_channels))
            self.WMLP.append(self.activation)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-initialise the linear weights, zero the biases, glorot-initialise ``att``."""
        for lin in self.WMLP:
            if isinstance(lin, Linear):
                torch.nn.init.xavier_uniform_(lin.weight)
                lin.bias.data.fill_(0)
        glorot(self.att)

    def forward(self, node_feature, geo_encoding, edge_index, is_source):
        """Aggregate neighbour features with learned attention weights.

        Args:
            node_feature: node feature tensor (required; it is indexed by
                node id along dimension 0).
            geo_encoding: per-edge geometric encoding.
            edge_index: pair ``(j, i)`` of source / target node index rows.
            is_source: unused here.

        Returns:
            ``node_feature`` plus the attention-weighted sum of the features
            of each node's incoming neighbours.
        """
        j, i = edge_index
        input_feature = node_feature.clone()
        # One row per edge: [target features, source features, geometry].
        x_i = torch.cat(
            [
                node_feature[i],
                node_feature[j],
                geo_encoding,
            ],
            dim=-1,
        )
        for lin in self.WMLP:
            x_i = lin(x_i)
        input_feature_j = input_feature[edge_index[0]]
        x_i = F.leaky_relu(x_i)
        alpha = F.leaky_relu(x_i * self.att).sum(dim=-1)
        # Normalise the scores over the incoming edges of every target node.
        alpha = softmax(alpha, edge_index[1])

        message = input_feature_j * alpha.unsqueeze(-1)
        # dim_size keeps one output row per node even when the highest node
        # ids have no incoming edge; otherwise the residual add below would
        # fail on mismatching shapes.
        out_feature = scatter(message, edge_index[1], dim=0,
                              dim_size=input_feature.size(0), reduce='add')
        out_feature = input_feature + out_feature

        return out_feature
|
385 |
+
|
386 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file may be used to create an environment using:
|
2 |
+
# $ conda create --name <env> --file <this file>
|
3 |
+
# platform: linux-64
|
4 |
+
_libgcc_mutex=0.1=conda_forge
|
5 |
+
_openmp_mutex=4.5=2_gnu
|
6 |
+
_py-xgboost-mutex=2.0=cpu_0
|
7 |
+
absl-py=1.3.0=py310h06a4308_0
|
8 |
+
aiohttp=3.8.1=py310h7f8727e_1
|
9 |
+
aiosignal=1.2.0=pyhd3eb1b0_0
|
10 |
+
anyio=3.6.2=pyhd8ed1ab_0
|
11 |
+
argon2-cffi=21.3.0=pyhd8ed1ab_0
|
12 |
+
argon2-cffi-bindings=21.2.0=py310h5764c6d_3
|
13 |
+
asttokens=2.0.5=pyhd3eb1b0_0
|
14 |
+
async-timeout=4.0.2=py310h06a4308_0
|
15 |
+
attrs=21.4.0=pyhd3eb1b0_0
|
16 |
+
autograd=1.5=pyhd8ed1ab_0
|
17 |
+
autopep8=1.6.0=pyhd3eb1b0_1
|
18 |
+
backcall=0.2.0=pyhd3eb1b0_0
|
19 |
+
beautifulsoup4=4.11.2=pyha770c72_0
|
20 |
+
blas=1.0=mkl
|
21 |
+
bleach=6.0.0=pyhd8ed1ab_0
|
22 |
+
blinker=1.4=py310h06a4308_0
|
23 |
+
blosc=1.21.1=h83bc5f7_3
|
24 |
+
boost-cpp=1.74.0=h75c5d50_8
|
25 |
+
bottleneck=1.3.5=py310ha9d4c09_0
|
26 |
+
branca=0.5.0=pyhd8ed1ab_0
|
27 |
+
brotli=1.0.9=h5eee18b_7
|
28 |
+
brotli-bin=1.0.9=h5eee18b_7
|
29 |
+
brotlipy=0.7.0=py310h7f8727e_1002
|
30 |
+
bzip2=1.0.8=h7b6447c_0
|
31 |
+
c-ares=1.18.1=h7f8727e_0
|
32 |
+
ca-certificates=2022.12.7=ha878542_0
|
33 |
+
cachetools=4.2.2=pyhd3eb1b0_0
|
34 |
+
cairo=1.16.0=h19f5f5c_2
|
35 |
+
certifi=2022.12.7=pyhd8ed1ab_0
|
36 |
+
cffi=1.15.1=py310h74dc2b5_0
|
37 |
+
cfitsio=4.1.0=hd9d235c_0
|
38 |
+
charset-normalizer=2.0.4=pyhd3eb1b0_0
|
39 |
+
click=8.0.4=py310h06a4308_0
|
40 |
+
click-plugins=1.1.1=pyhd3eb1b0_0
|
41 |
+
cligj=0.7.2=pyhd3eb1b0_0
|
42 |
+
cryptography=38.0.1=py310h9ce1e76_0
|
43 |
+
cudatoolkit=11.6.0=hecad31d_10
|
44 |
+
curl=7.84.0=h5eee18b_0
|
45 |
+
cycler=0.11.0=pyhd3eb1b0_0
|
46 |
+
dataclasses=0.8=pyh6d0b6a4_7
|
47 |
+
debugpy=1.5.1=py310h295c915_0
|
48 |
+
decorator=5.1.1=pyhd3eb1b0_0
|
49 |
+
defusedxml=0.7.1=pyhd8ed1ab_0
|
50 |
+
entrypoints=0.4=py310h06a4308_0
|
51 |
+
executing=0.8.3=pyhd3eb1b0_0
|
52 |
+
expat=2.4.9=h6a678d5_0
|
53 |
+
ffmpeg=4.2.2=h20bf706_0
|
54 |
+
fftw=3.3.9=h27cfd23_1
|
55 |
+
fiona=1.8.21=py310h60a68a4_2
|
56 |
+
flit-core=3.8.0=pyhd8ed1ab_0
|
57 |
+
folium=0.12.1.post1=pyhd8ed1ab_1
|
58 |
+
font-ttf-dejavu-sans-mono=2.37=hd3eb1b0_0
|
59 |
+
font-ttf-inconsolata=2.001=hcb22688_0
|
60 |
+
font-ttf-source-code-pro=2.030=hd3eb1b0_0
|
61 |
+
font-ttf-ubuntu=0.83=h8b1ccd4_0
|
62 |
+
fontconfig=2.14.0=h8e229c2_0
|
63 |
+
fonts-anaconda=1=h8fa9717_0
|
64 |
+
fonts-conda-ecosystem=1=hd3eb1b0_0
|
65 |
+
fonttools=4.25.0=pyhd3eb1b0_0
|
66 |
+
freetype=2.11.0=h70c0345_0
|
67 |
+
freexl=1.0.6=h27cfd23_0
|
68 |
+
frozenlist=1.2.0=py310h7f8727e_1
|
69 |
+
future=0.18.3=pyhd8ed1ab_0
|
70 |
+
gdal=3.5.1=py310hce6f0df_1
|
71 |
+
geographiclib=1.52=pyhd8ed1ab_0
|
72 |
+
geopandas=0.11.1=pyhd8ed1ab_0
|
73 |
+
geopandas-base=0.11.1=pyha770c72_0
|
74 |
+
geopy=2.2.0=pyhd8ed1ab_0
|
75 |
+
geos=3.11.0=h27087fc_0
|
76 |
+
geotiff=1.7.1=h4fc65e6_3
|
77 |
+
gettext=0.21.0=hf68c758_0
|
78 |
+
giflib=5.2.1=h7b6447c_0
|
79 |
+
glib=2.72.1=h6239696_0
|
80 |
+
glib-tools=2.72.1=h6239696_0
|
81 |
+
gmp=6.2.1=h295c915_3
|
82 |
+
gnutls=3.6.15=he1e5248_0
|
83 |
+
google-auth=2.6.0=pyhd3eb1b0_0
|
84 |
+
google-auth-oauthlib=0.4.4=pyhd3eb1b0_0
|
85 |
+
gpy=1.10.0=py310hde88566_3
|
86 |
+
grpcio=1.42.0=py310hce63b2e_0
|
87 |
+
hdf4=4.2.15=h9772cbc_4
|
88 |
+
hdf5=1.12.1=nompi_h2386368_104
|
89 |
+
icu=70.1=h27087fc_0
|
90 |
+
idna=3.4=py310h06a4308_0
|
91 |
+
importlib-metadata=4.11.3=py310h06a4308_0
|
92 |
+
importlib_resources=5.12.0=pyhd8ed1ab_0
|
93 |
+
intel-openmp=2021.4.0=h06a4308_3561
|
94 |
+
ipykernel=6.15.2=py310h06a4308_0
|
95 |
+
ipython=8.4.0=py310h06a4308_0
|
96 |
+
ipython_genutils=0.2.0=py_1
|
97 |
+
jedi=0.18.1=py310h06a4308_1
|
98 |
+
jinja2=3.1.2=py310h06a4308_0
|
99 |
+
joblib=1.1.1=py310h06a4308_0
|
100 |
+
jpeg=9e=h7f8727e_0
|
101 |
+
json-c=0.16=h5eee18b_0
|
102 |
+
jsonschema=4.17.3=pyhd8ed1ab_0
|
103 |
+
jupyter_client=7.3.4=py310h06a4308_0
|
104 |
+
jupyter_core=4.10.0=py310h06a4308_0
|
105 |
+
jupyter_server=1.23.4=py310h06a4308_0
|
106 |
+
jupyterlab_pygments=0.2.2=pyhd8ed1ab_0
|
107 |
+
kealib=1.4.15=hfe1a663_0
|
108 |
+
keyutils=1.6.1=h166bdaf_0
|
109 |
+
kiwisolver=1.4.2=py310h295c915_0
|
110 |
+
krb5=1.19.3=h3790be6_0
|
111 |
+
lame=3.100=h7b6447c_0
|
112 |
+
lcms2=2.12=h3be6417_0
|
113 |
+
ld_impl_linux-64=2.38=h1181459_1
|
114 |
+
lerc=3.0=h295c915_0
|
115 |
+
libbrotlicommon=1.0.9=h5eee18b_7
|
116 |
+
libbrotlidec=1.0.9=h5eee18b_7
|
117 |
+
libbrotlienc=1.0.9=h5eee18b_7
|
118 |
+
libcurl=7.84.0=h91b91d3_0
|
119 |
+
libdap4=3.20.6=hd7c4107_2
|
120 |
+
libdeflate=1.8=h7f8727e_5
|
121 |
+
libedit=3.1.20210910=h7f8727e_0
|
122 |
+
libev=4.33=h7f8727e_1
|
123 |
+
libffi=3.4.2=h7f98852_5
|
124 |
+
libgcc-ng=12.1.0=h8d9b700_16
|
125 |
+
libgdal=3.5.1=h32640fd_1
|
126 |
+
libgfortran-ng=11.2.0=h00389a5_1
|
127 |
+
libgfortran5=11.2.0=h1234567_1
|
128 |
+
libglib=2.72.1=h2d90d5f_0
|
129 |
+
libgomp=12.1.0=h8d9b700_16
|
130 |
+
libiconv=1.16=h7f8727e_2
|
131 |
+
libidn2=2.3.2=h7f8727e_0
|
132 |
+
libkml=1.3.0=h238a007_1014
|
133 |
+
libnetcdf=4.8.1=nompi_h329d8a1_102
|
134 |
+
libnghttp2=1.46.0=hce63b2e_0
|
135 |
+
libnsl=2.0.0=h5eee18b_0
|
136 |
+
libopus=1.3.1=h7b6447c_0
|
137 |
+
libpng=1.6.37=hbc83047_0
|
138 |
+
libpq=14.5=hd77ab85_0
|
139 |
+
libprotobuf=3.20.1=h4ff587b_0
|
140 |
+
libpysal=4.1.1=py_0
|
141 |
+
librttopo=1.1.0=hf730bdb_11
|
142 |
+
libsodium=1.0.18=h7b6447c_0
|
143 |
+
libspatialindex=1.9.3=h2531618_0
|
144 |
+
libspatialite=5.0.1=h38b5f51_18
|
145 |
+
libsqlite=3.39.3=h753d276_0
|
146 |
+
libssh2=1.10.0=h8f2d780_0
|
147 |
+
libstdcxx-ng=12.1.0=ha89aaad_16
|
148 |
+
libtasn1=4.16.0=h27cfd23_0
|
149 |
+
libtiff=4.4.0=hecacb30_0
|
150 |
+
libunistring=0.9.10=h27cfd23_0
|
151 |
+
libuuid=2.32.1=h7f98852_1000
|
152 |
+
libvpx=1.7.0=h439df22_0
|
153 |
+
libwebp=1.2.4=h11a3e52_0
|
154 |
+
libwebp-base=1.2.4=h5eee18b_0
|
155 |
+
libxcb=1.15=h7f8727e_0
|
156 |
+
libxgboost=1.7.1=cpu_ha3b9936_0
|
157 |
+
libxml2=2.9.14=h22db469_4
|
158 |
+
libzip=1.8.0=h5cef20c_0
|
159 |
+
libzlib=1.2.12=h166bdaf_2
|
160 |
+
littleutils=0.2.2=py_0
|
161 |
+
llvm-openmp=8.0.1=hc9558a2_0
|
162 |
+
lz4-c=1.9.3=h295c915_1
|
163 |
+
mapclassify=2.4.3=pyhd3eb1b0_0
|
164 |
+
markdown=3.3.4=py310h06a4308_0
|
165 |
+
markupsafe=2.1.1=py310h7f8727e_0
|
166 |
+
matplotlib-base=3.5.2=py310hf590b9c_0
|
167 |
+
matplotlib-inline=0.1.6=py310h06a4308_0
|
168 |
+
mgwr=2.1.2=py_0
|
169 |
+
mistune=2.0.5=pyhd8ed1ab_0
|
170 |
+
mkl=2021.4.0=h06a4308_640
|
171 |
+
mkl-service=2.4.0=py310h7f8727e_0
|
172 |
+
mkl_fft=1.3.1=py310hd6ae3a3_0
|
173 |
+
mkl_random=1.2.2=py310h00e6091_0
|
174 |
+
multidict=6.0.2=py310h5eee18b_0
|
175 |
+
munch=2.5.0=pyhd3eb1b0_0
|
176 |
+
munkres=1.1.4=py_0
|
177 |
+
nbclassic=0.5.3=pyhb4ecaf3_3
|
178 |
+
nbclient=0.5.13=pyhd8ed1ab_0
|
179 |
+
nbconvert=7.2.9=pyhd8ed1ab_0
|
180 |
+
nbconvert-core=7.2.9=pyhd8ed1ab_0
|
181 |
+
nbconvert-pandoc=7.2.9=pyhd8ed1ab_0
|
182 |
+
nbformat=5.7.3=pyhd8ed1ab_0
|
183 |
+
ncurses=6.3=h5eee18b_3
|
184 |
+
nest-asyncio=1.5.5=py310h06a4308_0
|
185 |
+
nettle=3.7.3=hbbd107a_1
|
186 |
+
networkx=2.8.4=py310h06a4308_0
|
187 |
+
notebook=6.5.3=pyha770c72_0
|
188 |
+
notebook-shim=0.2.2=pyhd8ed1ab_0
|
189 |
+
nspr=4.33=h295c915_0
|
190 |
+
nss=3.78=h2350873_0
|
191 |
+
numexpr=2.8.3=py310hcea2de6_0
|
192 |
+
numpy=1.23.1=py310h1794996_0
|
193 |
+
numpy-base=1.23.1=py310hcba007f_0
|
194 |
+
oauthlib=3.2.1=py310h06a4308_0
|
195 |
+
ogb=1.3.5=pyhd8ed1ab_0
|
196 |
+
openai=0.28.0=pypi_0
|
197 |
+
openh264=2.1.1=h4ff587b_0
|
198 |
+
openjpeg=2.4.0=h3ad879b_0
|
199 |
+
openssl=1.1.1t=h0b41bf4_0
|
200 |
+
outdated=0.2.2=pyhd8ed1ab_0
|
201 |
+
packaging=21.3=pyhd3eb1b0_0
|
202 |
+
pandas=1.4.3=py310h6a678d5_0
|
203 |
+
pandoc=2.19.2=ha770c72_0
|
204 |
+
pandocfilters=1.5.0=pyhd8ed1ab_0
|
205 |
+
paramz=0.9.5=py_0
|
206 |
+
parso=0.8.3=pyhd3eb1b0_0
|
207 |
+
pcre=8.45=h295c915_0
|
208 |
+
pexpect=4.8.0=pyhd3eb1b0_3
|
209 |
+
pickleshare=0.7.5=pyhd3eb1b0_1003
|
210 |
+
pillow=9.2.0=py310hace64e9_1
|
211 |
+
pip=22.1.2=py310h06a4308_0
|
212 |
+
pixman=0.40.0=h7f8727e_1
|
213 |
+
pkgutil-resolve-name=1.3.10=pyhd8ed1ab_0
|
214 |
+
poppler=22.04.0=h1434ded_1
|
215 |
+
poppler-data=0.4.11=h06a4308_0
|
216 |
+
postgresql=14.5=hfdbbde3_0
|
217 |
+
proj=9.0.1=h93bde94_1
|
218 |
+
prometheus_client=0.16.0=pyhd8ed1ab_0
|
219 |
+
prompt-toolkit=3.0.20=pyhd3eb1b0_0
|
220 |
+
protobuf=3.20.1=py310h295c915_0
|
221 |
+
psutil=5.9.0=py310h5eee18b_0
|
222 |
+
ptyprocess=0.7.0=pyhd3eb1b0_2
|
223 |
+
pure_eval=0.2.2=pyhd3eb1b0_0
|
224 |
+
py-xgboost=1.7.1=cpu_py310hd1aba9c_0
|
225 |
+
pyasn1=0.4.8=pyhd3eb1b0_0
|
226 |
+
pyasn1-modules=0.2.8=py_0
|
227 |
+
pycodestyle=2.8.0=pyhd3eb1b0_0
|
228 |
+
pycparser=2.21=pyhd3eb1b0_0
|
229 |
+
pygments=2.11.2=pyhd3eb1b0_0
|
230 |
+
pyjwt=2.4.0=py310h06a4308_0
|
231 |
+
pyopenssl=22.0.0=pyhd3eb1b0_0
|
232 |
+
pyparsing=3.0.9=py310h06a4308_0
|
233 |
+
pyproj=3.4.0=py310hf94497c_0
|
234 |
+
pyrsistent=0.19.3=py310h1fa729e_0
|
235 |
+
pysocks=1.7.1=py310h06a4308_0
|
236 |
+
python=3.10.6=h582c2e5_0_cpython
|
237 |
+
python-dateutil=2.8.2=pyhd3eb1b0_0
|
238 |
+
python-fastjsonschema=2.16.3=pyhd8ed1ab_0
|
239 |
+
python_abi=3.10=2_cp310
|
240 |
+
pytorch=1.12.1=py3.10_cuda11.6_cudnn8.3.2_0
|
241 |
+
pytorch-mutex=1.0=cuda
|
242 |
+
pytz=2022.1=py310h06a4308_0
|
243 |
+
pyzmq=23.2.0=py310h6a678d5_0
|
244 |
+
readline=8.1.2=h7f8727e_1
|
245 |
+
requests=2.28.1=py310h06a4308_0
|
246 |
+
requests-oauthlib=1.3.0=py_0
|
247 |
+
rsa=4.7.2=pyhd3eb1b0_1
|
248 |
+
rtree=0.9.7=py310h06a4308_1
|
249 |
+
scikit-learn=1.1.3=py310h6a678d5_0
|
250 |
+
scipy=1.9.3=py310hd5efca6_0
|
251 |
+
send2trash=1.8.0=pyhd8ed1ab_0
|
252 |
+
setuptools=63.4.1=py310h06a4308_0
|
253 |
+
shapely=1.8.4=py310h5e49deb_0
|
254 |
+
six=1.16.0=pyhd3eb1b0_1
|
255 |
+
snappy=1.1.9=h295c915_0
|
256 |
+
sniffio=1.3.0=pyhd8ed1ab_0
|
257 |
+
soupsieve=2.3.2.post1=pyhd8ed1ab_0
|
258 |
+
spglm=1.0.8=py_0
|
259 |
+
spreg=1.3.0=pyhd8ed1ab_0
|
260 |
+
sqlite=3.39.2=h5082296_0
|
261 |
+
stack_data=0.2.0=pyhd3eb1b0_0
|
262 |
+
tensorboard=2.10.1=pyhd8ed1ab_0
|
263 |
+
tensorboard-data-server=0.6.0=py310hca6d32c_0
|
264 |
+
tensorboard-plugin-wit=1.8.1=py310h06a4308_0
|
265 |
+
terminado=0.17.1=pyh41d4057_0
|
266 |
+
threadpoolctl=2.2.0=pyh0d69192_0
|
267 |
+
tiledb=2.9.5=h1e4a385_0
|
268 |
+
tinycss2=1.2.1=pyhd8ed1ab_0
|
269 |
+
tk=8.6.12=h1ccaba5_0
|
270 |
+
toml=0.10.2=pyhd3eb1b0_0
|
271 |
+
torch-cluster=1.6.0=pypi_0
|
272 |
+
torch-geometric=2.1.0.post1=pypi_0
|
273 |
+
torch-scatter=2.0.9=pypi_0
|
274 |
+
torch-sparse=0.6.15=pypi_0
|
275 |
+
torch-spline-conv=1.2.1=pypi_0
|
276 |
+
torch-tb-profiler=0.4.0=pypi_0
|
277 |
+
torchaudio=0.12.1=py310_cu116
|
278 |
+
torchvision=0.13.1=py310_cu116
|
279 |
+
tornado=6.2=py310h5eee18b_0
|
280 |
+
tqdm=4.64.0=py310h06a4308_0
|
281 |
+
traitlets=5.1.1=pyhd3eb1b0_0
|
282 |
+
typing_extensions=4.3.0=py310h06a4308_0
|
283 |
+
tzcode=2022c=h166bdaf_0
|
284 |
+
tzdata=2022a=hda174b7_0
|
285 |
+
urllib3=1.26.12=py310h06a4308_0
|
286 |
+
wcwidth=0.2.5=pyhd3eb1b0_0
|
287 |
+
webencodings=0.5.1=py_1
|
288 |
+
websocket-client=1.5.1=pyhd8ed1ab_0
|
289 |
+
werkzeug=2.0.3=pyhd3eb1b0_0
|
290 |
+
wheel=0.37.1=pyhd3eb1b0_0
|
291 |
+
x264=1!157.20191217=h7b6447c_0
|
292 |
+
xerces-c=3.2.3=h55805fa_5
|
293 |
+
xgboost=1.7.1=cpu_py310hd1aba9c_0
|
294 |
+
xyzservices=2022.9.0=py310h06a4308_0
|
295 |
+
xz=5.2.6=h166bdaf_0
|
296 |
+
yarl=1.8.1=py310h5eee18b_0
|
297 |
+
zeromq=4.3.4=h2531618_0
|
298 |
+
zipp=3.8.0=py310h06a4308_0
|
299 |
+
zlib=1.2.12=h7f8727e_2
|
300 |
+
zstd=1.5.2=ha4553b6_0
|
run.sh
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Choose the dataset passed to train.py; uncomment one alternative to switch.
# dataset=north
# dataset=south
dataset=flu
# Quote the expansion so an empty or space-containing value cannot split into extra arguments.
python3 ./train.py --dataset "$dataset" --manualSeed True --man_seed 5770
|
train.py
ADDED
@@ -0,0 +1,421 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
import random
|
4 |
+
import torch
|
5 |
+
import pandas as pd
|
6 |
+
import numpy as np
|
7 |
+
import time
|
8 |
+
import torch.optim as optim
|
9 |
+
|
10 |
+
from matplotlib import cm
|
11 |
+
import matplotlib.pyplot as plt
|
12 |
+
import json
|
13 |
+
from model import GFusion
|
14 |
+
import torch.nn.functional as F
|
15 |
+
from torch_geometric.data import Data
|
16 |
+
from torch_geometric.loader import DataLoader
|
17 |
+
|
18 |
+
from torch_geometric.utils import add_self_loops
|
19 |
+
from torch.nn.functional import softmax
|
20 |
+
from torch_geometric.nn import knn_graph
|
21 |
+
import copy
|
22 |
+
|
23 |
+
torch.autograd.set_detect_anomaly(True)
|
24 |
+
from sklearn.metrics import explained_variance_score,mean_squared_error,mean_absolute_error,r2_score,precision_score,recall_score,f1_score,roc_auc_score,roc_curve, auc
|
25 |
+
from sklearn.feature_selection import r_regression
|
26 |
+
import pickle
|
27 |
+
from utils.utils import triplets,unique,pos2key
|
28 |
+
from torch.utils.tensorboard import SummaryWriter
|
29 |
+
from datetime import datetime
|
30 |
+
import dataset
|
31 |
+
|
32 |
+
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Parameters whose ``requires_grad`` is False are left out of the total.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
|
34 |
+
|
35 |
+
def _ansi(code, x):
    """Wrap the string x in the ANSI SGR escape for `code`, followed by a reset."""
    return '\033[' + code + 'm' + x + '\033[0m'


def blue(x):
    """Return x wrapped in ANSI code 94 (bright blue foreground)."""
    return _ansi('94', x)


def red(x):
    """Return x wrapped in ANSI code 31 (red foreground)."""
    return _ansi('31', x)


def green(x):
    """Return x wrapped in ANSI code 32 (green foreground)."""
    return _ansi('32', x)


def yellow(x):
    """Return x wrapped in ANSI code 33 (yellow foreground)."""
    return _ansi('33', x)


def greenline(x):
    """Return x wrapped in ANSI code 42 (green background)."""
    return _ansi('42', x)


def yellowline(x):
    """Return x wrapped in ANSI code 43 (yellow background)."""
    return _ansi('43', x)
|
41 |
+
|
42 |
+
def _parse_bool_flag(value):
    """Return True when a command-line string spells a true value.

    "true", "1" and "yes" are accepted in any letter case (surrounding
    whitespace ignored). Every other string maps to False, which keeps the
    previous result for the literal "False" and for unrecognised values.
    """
    return str(value).strip().lower() in ("true", "1", "yes")


def get_args():
    """Parse the command line and return the populated argparse namespace.

    Boolean options are passed as strings (e.g. ``--log True``) and converted
    to real booleans after parsing. Two derived fields are set as well:

    * ``fidelity_train`` is True only when it was requested, ``single_high``
      is off and ``fidelity_low_weight`` is left at its sentinel -1.0.
    * ``save_dir`` is ``./save/<dataset>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--log', type=str, default="True")
    parser.add_argument('--loadmodel', type=str, default="False")
    parser.add_argument('--split_dataset', type=str, default="False")
    parser.add_argument('--model', type=str, default="GFusion")

    # Ablation switches.
    parser.add_argument('--edge_rep', type=str, default="True")
    parser.add_argument('--single_high', type=str, default="False")
    parser.add_argument('--fidelity_train', type=str, default="True")
    parser.add_argument('--fidelity_low_weight', type=float, default=-1.0)
    parser.add_argument('--share', type=str, default="101")

    parser.add_argument('--dataset', type=str, default='flu')
    parser.add_argument('--manualSeed', type=str, default="False")
    parser.add_argument('--man_seed', type=int, default=12345)
    parser.add_argument('--test_per_round', type=int, default=10)
    parser.add_argument('--patience', type=int, default=30)  # scheduler patience
    parser.add_argument('--nepoch', type=int, default=201)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--activation', type=str, default='relu')  # alternative: 'lrelu'
    parser.add_argument('--batchSize', type=int, default=512)

    parser.add_argument('--num_neighbors', type=int, default=3)
    parser.add_argument('--regression_loss', type=str, default='l2')

    parser.add_argument('--h_ch', type=int, default=16)
    parser.add_argument('--localdepth', type=int, default=1)  # mlp(distance) mlp(theta), >=1
    parser.add_argument('--num_interactions', type=int, default=1)  # >=1
    parser.add_argument('--finaldepth', type=int, default=3)  # mlp(concat node_attr and geo_encoding)

    args = parser.parse_args()

    # Convert the string-valued switches in place (attribute order is unchanged).
    for name in ('log', 'loadmodel', 'split_dataset', 'edge_rep', 'single_high', 'manualSeed'):
        setattr(args, name, _parse_bool_flag(getattr(args, name)))

    # Learned fidelity weights only make sense when both fidelities are used
    # and no fixed low-fidelity weight was supplied.
    args.fidelity_train = (
        _parse_bool_flag(args.fidelity_train)
        and not args.single_high
        and args.fidelity_low_weight == -1.0
    )
    args.save_dir = os.path.join('./save/', args.dataset)
    return args
|
84 |
+
|
85 |
+
def main(args,train_Loader,val_Loader,test_Loader):
|
86 |
+
if flag:
|
87 |
+
return
|
88 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
89 |
+
measure_Pearsonr=r_regression
|
90 |
+
criterion_l1 = torch.nn.L1Loss() #reduction='sum'
|
91 |
+
criterion_l2 = torch.nn.MSELoss()
|
92 |
+
criterion=criterion_l1 if args.regression_loss=='l1' else criterion_l2
|
93 |
+
if args.model in ['GFusion']:
|
94 |
+
def myL1(pred,true,weight=None,reduction='mean'):
    """Absolute-error loss reduced over the first (sample) dimension.

    Args:
        pred: predictions; the first dimension indexes samples.
        true: targets with the same shape as ``pred``.
        weight: optional per-sample weights, indexable by sample position.
            Entries may be Python numbers or 0-dim tensors; tensor entries
            stay attached to autograd so gradients reach them.
        reduction: ``'mean'`` divides the summed loss by the number of
            samples; any other value returns the plain sum.

    Returns:
        The loss with the first dimension reduced away (shape ``(1,)`` for
        ``(N, 1)`` inputs).
    """
    loss = torch.abs(pred - true)
    num = len(pred)
    if weight is not None and num > 0:
        # Stack the weights once and broadcast over trailing dims instead of
        # multiplying sample by sample in Python.
        stacked = torch.stack([
            torch.as_tensor(weight[i], dtype=loss.dtype, device=loss.device)
            for i in range(num)
        ])
        loss = loss * stacked.reshape((num,) + (1,) * (loss.dim() - 1))
    loss = loss.sum(dim=0)
    if reduction == 'mean':
        loss = loss / num
    return loss
|
103 |
+
def myL2(pred,true,weight=None,reduction='mean'):
    """Squared-error loss reduced over the first (sample) dimension.

    Args:
        pred: predictions; the first dimension indexes samples.
        true: targets with the same shape as ``pred``.
        weight: optional per-sample weights, indexable by sample position.
            Entries may be Python numbers or 0-dim tensors; tensor entries
            stay attached to autograd so gradients reach them.
        reduction: ``'mean'`` divides the summed loss by the number of
            samples; any other value returns the plain sum.

    Returns:
        The loss with the first dimension reduced away (shape ``(1,)`` for
        ``(N, 1)`` inputs).
    """
    loss = (pred - true) ** 2
    num = len(pred)
    if weight is not None and num > 0:
        # Stack the weights once and broadcast over trailing dims instead of
        # multiplying sample by sample in Python.
        stacked = torch.stack([
            torch.as_tensor(weight[i], dtype=loss.dtype, device=loss.device)
            for i in range(num)
        ])
        loss = loss * stacked.reshape((num,) + (1,) * (loss.dim() - 1))
    loss = loss.sum(dim=0)
    if reduction == 'mean':
        loss = loss / num
    return loss
|
112 |
+
criterion=myL1 if args.regression_loss=='l1' else myL2
|
113 |
+
num_of_fidelities=len(train_graphs[0])
|
114 |
+
|
115 |
+
def reweight_fidelity():
    """Refresh ``weighted_fidelity_weight`` in place from the current settings.

    * ``args.single_high``: the first fidelity gets weight 1, the second 0.
    * a fixed ``args.fidelity_low_weight`` (anything but -1.0): weights are
      1 and that fixed value.
    * otherwise: softmax (exp / sum of exps) of the raw ``fidelity_weight``
      entries, one per fidelity.
    """
    if args.single_high:
        weighted_fidelity_weight[0], weighted_fidelity_weight[1] = 1, 0
        return
    if args.fidelity_low_weight != -1.0:
        weighted_fidelity_weight[0], weighted_fidelity_weight[1] = 1, args.fidelity_low_weight
        return
    exponentials = [torch.exp(fidelity_weight[k]) for k in range(num_of_fidelities)]
    normaliser = sum(exponentials)
    for k, exp_k in enumerate(exponentials):
        weighted_fidelity_weight[k] = exp_k / normaliser
|
127 |
+
fidelity_weight,weighted_fidelity_weight=[],[]
|
128 |
+
if args.dataset in ['south',"north","flu"]:
|
129 |
+
for i in range(num_of_fidelities):
|
130 |
+
fidelity_weight+=[torch.tensor(1.0/num_of_fidelities,dtype=torch.float32).requires_grad_()]
|
131 |
+
weighted_fidelity_weight+=[0]
|
132 |
+
elif args.dataset in ["syn"]:
|
133 |
+
fidelity_weight=[torch.tensor(1,dtype=torch.float32).requires_grad_(),torch.tensor(0.0,dtype=torch.float32).requires_grad_()]
|
134 |
+
for i in range(num_of_fidelities):
|
135 |
+
# fidelity_weight+=[torch.tensor(1.0/num_of_fidelities,dtype=torch.float32).requires_grad_()]
|
136 |
+
weighted_fidelity_weight+=[0]
|
137 |
+
reweight_fidelity()
|
138 |
+
if args.dataset in ['south',"north"]:
|
139 |
+
x_in=30
|
140 |
+
elif args.dataset in ['flu']:
|
141 |
+
x_in=0
|
142 |
+
elif args.dataset=='syn':
|
143 |
+
x_in=1
|
144 |
+
else:
|
145 |
+
raise Exception('Dataset not recognized.')
|
146 |
+
if args.model=="GFusion":
|
147 |
+
GFusion_model=GFusion(h_channel=args.h_ch,input_featuresize=x_in,\
|
148 |
+
localdepth=args.localdepth,num_interactions=args.num_interactions,finaldepth=args.finaldepth,share=args.share)
|
149 |
+
GFusion_model.to(device)
|
150 |
+
optimizer = torch.optim.Adam( list(GFusion_model.parameters()), lr=args.lr)
|
151 |
+
if args.fidelity_train:
|
152 |
+
optimizer2 = torch.optim.Adam(fidelity_weight, lr=optimizer.param_groups[0]['lr']*10)
|
153 |
+
scheduler2 = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer2, factor=0.1, patience=args.patience, min_lr=1e-8)
|
154 |
+
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.1, patience=args.patience, min_lr=1e-8)
|
155 |
+
|
156 |
+
def train(GFusion_model):
    """Run one optimisation pass over ``train_Loader``.

    Uses the enclosing scope's ``optimizer`` (and ``optimizer2`` when
    ``args.fidelity_train`` is set), ``criterion``, ``device`` and the
    fidelity-weight lists. Each batch is a pair of graphs, one per fidelity.

    Returns:
        (mean per-batch loss as a float, list of predictions, list of targets)
    """
    epochloss = 0.0
    y_hat, y_true = [], []
    optimizer.zero_grad()
    if args.fidelity_train: optimizer2.zero_grad()
    if args.model=="GFusion":
        GFusion_model.train()
    for data in train_Loader:
        if num_of_fidelities==2:
            x1, pos1,edge_index1, batch1,target_index1,target1,is_source1 = data[0].x, data[0].pos,data[0].edge_index, data[0].batch,data[0].target_index,data[0].target,data[0].is_source
            x2, pos2,edge_index2, batch2,target_index2,target2,is_source2 = data[1].x, data[1].pos,data[1].edge_index, data[1].batch,data[1].target_index,data[1].target,data[1].is_source
            if args.dataset=='syn':
                # Fold column 2 into column 1, then drop column 2.
                x1[:,1]=x1[:,1]+x1[:,2]
                x1=x1[:,[0,1,3,4]]
                x2[:,1]=x2[:,1]+x2[:,2]
                x2=x2[:,[0,1,3,4]]
            x1,pos1,target1,x2,pos2,target2=x1.to(torch.float32),pos1.to(torch.float32),target1.to(torch.float32),x2.to(torch.float32),pos2.to(torch.float32),target2.to(torch.float32)
            # NOTE(review): 6666 looks like an outlier cap on the first feature and on the target - confirm.
            x2[x2[:,0]>6666,0]=6666
            datasource=data[0].datasource
            Y = target1
            assert(torch.equal(target1,target2))
            Y[Y>6666]=6666
            x1, pos1,edge_index1, batch1, target_index1,is_source1 = x1.to(device),pos1.to(device), edge_index1.to(device), batch1.to(device),target_index1.to(device),is_source1.to(device)
            x2, pos2,edge_index2, batch2, target_index2,is_source2 = x2.to(device),pos2.to(device),edge_index2.to(device), batch2.to(device),target_index2.to(device),is_source2.to(device)

            # Triplets are computed per graph; the two fidelities do not share them.
            num_nodes1=x1.shape[0]
            num_nodes2=x2.shape[0]
            edge_index_2rd_1, _, _, edx_2nd_1 = triplets(edge_index1, num_nodes1)
            edge_index_2rd_2, _, _, edx_2nd_2 = triplets(edge_index2, num_nodes2)

            pm25_1,pm25_2=GFusion_model([pos1,pos2],[edge_index1,edge_index2],[edge_index_2rd_1,edge_index_2rd_2],\
                [edx_2nd_1,edx_2nd_2],[batch1,batch2],[x1,x2],[is_source1,is_source2],args.edge_rep)
            pm25_1,pm25_2=pm25_1[target_index1],pm25_2[target_index2]

            # Blend the two per-fidelity outputs; non-synthetic targets are clamped to be non-negative.
            if args.dataset=='syn':
                pred=((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
            else:
                pred=F.relu((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())

            # Each sample's loss is scaled by the weight of the fidelity it came from.
            loss_weight= [weighted_fidelity_weight[i] for i in datasource]
            loss = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1),loss_weight)

            # Record predictions.
            y_hat += list(pred.detach().numpy().reshape(-1))
            y_true += list(Y.detach().numpy().reshape(-1))

            loss.backward()
            # Accumulate a plain float so no autograd history is carried across batches.
            epochloss += loss.item()
            optimizer.step()
            optimizer.zero_grad()
            if args.fidelity_train:
                optimizer2.step()
                optimizer2.zero_grad()
            # Rebuild the normalised weights from the freshly updated raw weights.
            reweight_fidelity()
    return epochloss/len(train_Loader),y_hat, y_true
|
215 |
+
|
216 |
+
def test(loader,GFusion_model,fidelity_weight):
    """Evaluate ``GFusion_model`` on ``loader`` without tracking gradients.

    Args:
        loader: iterable of batches, each a pair of graphs (one per fidelity).
        GFusion_model: the model to evaluate; it is switched to eval mode.
        fidelity_weight: raw fidelity weights; their softmax blends the two
            model outputs unless ``args.single_high`` is set.

    Returns:
        (sample-weighted mean loss, list of predictions, list of targets)
    """
    # Evaluation needs no autograd graph, so the whole pass runs under no_grad.
    with torch.no_grad():
        if not args.single_high:
            # NOTE(review): a fixed args.fidelity_low_weight is not applied here,
            # unlike reweight_fidelity() during training - confirm this is intended.
            exped_f=[torch.exp(fidelity_weight[i]) for i in range(num_of_fidelities)]
            fsum=sum(exped_f)
            weighted_fidelity_weight=[e/fsum for e in exped_f]
        else:
            weighted_fidelity_weight=[1,0]
        y_hat, y_true = [], []
        loss_total, pred_num = 0, 0
        GFusion_model.eval()
        for data in loader:
            if num_of_fidelities==2:
                x1, pos1,edge_index1, batch1,target_index1,target1,is_source1 = data[0].x, data[0].pos,data[0].edge_index, data[0].batch,data[0].target_index,data[0].target,data[0].is_source
                x2, pos2,edge_index2, batch2,target_index2,target2,is_source2 = data[1].x, data[1].pos,data[1].edge_index, data[1].batch,data[1].target_index,data[1].target,data[1].is_source
                if args.dataset=='syn':
                    # Fold column 2 into column 1, then drop column 2.
                    x1[:,1]=x1[:,1]+x1[:,2]
                    x1=x1[:,[0,1,3,4]]
                    x2[:,1]=x2[:,1]+x2[:,2]
                    x2=x2[:,[0,1,3,4]]
                x1,pos1,target1,x2,pos2,target2=x1.to(torch.float32),pos1.to(torch.float32),target1.to(torch.float32),x2.to(torch.float32),pos2.to(torch.float32),target2.to(torch.float32)
                # NOTE(review): 6666 looks like an outlier cap on the first feature and on the target - confirm.
                x2[x2[:,0]>6666,0]=6666
                datasource=data[0].datasource
                Y = target1
                assert(torch.equal(target1,target2))
                Y[Y>6666]=6666
                x1, pos1,edge_index1, batch1, target_index1,is_source1 = x1.to(device),pos1.to(device), edge_index1.to(device), batch1.to(device),target_index1.to(device),is_source1.to(device)
                x2, pos2,edge_index2, batch2, target_index2,is_source2 = x2.to(device),pos2.to(device),edge_index2.to(device), batch2.to(device),target_index2.to(device),is_source2.to(device)

                num_nodes1=x1.shape[0]
                num_nodes2=x2.shape[0]
                edge_index_2rd_1, _, _, edx_2nd_1 = triplets(edge_index1, num_nodes1)
                edge_index_2rd_2, _, _, edx_2nd_2 = triplets(edge_index2, num_nodes2)
                pm25_1,pm25_2=GFusion_model([pos1,pos2],[edge_index1,edge_index2],[edge_index_2rd_1,edge_index_2rd_2],\
                    [edx_2nd_1,edx_2nd_2],[batch1,batch2],[x1,x2],[is_source1,is_source2],args.edge_rep)
                pm25_1,pm25_2=pm25_1[target_index1],pm25_2[target_index2]

                # Blend the two per-fidelity outputs; non-synthetic targets are clamped to be non-negative.
                if args.dataset=='syn':
                    pred=((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
                else:
                    pred=F.relu((pm25_1*weighted_fidelity_weight[0]+pm25_2*weighted_fidelity_weight[1]).cpu())
                # Evaluation batches are expected to hold only samples from source 0.
                assert(all(datasource==0))
                loss = criterion(pred.reshape(-1, 1), Y.reshape(-1, 1))*weighted_fidelity_weight[0]

                # Record predictions.
                y_hat += list(pred.detach().numpy().reshape(-1))
                y_true += list(Y.detach().numpy().reshape(-1))
                pred_num += len(Y.reshape(-1, 1))
                # Weight each batch loss by its sample count for the overall mean.
                loss_total += loss.detach() * len(Y.reshape(-1, 1))
    return loss_total/pred_num, y_hat, y_true
|
270 |
+
if args.loadmodel:
|
271 |
+
try:
|
272 |
+
suffix='Oct31-11:50:30'
|
273 |
+
GFusion_model.load_state_dict(torch.load(os.path.join("save",args.dataset,'model','best_GFusion_model_'+suffix+'.pth')),strict=True)
|
274 |
+
best_GFusion_model = copy.deepcopy(GFusion_model)
|
275 |
+
except OSError:
|
276 |
+
pass
|
277 |
+
else:
|
278 |
+
best_val_trigger = 1e3
|
279 |
+
old_lr=1e3
|
280 |
+
suffix="{}{}-{}:{}:{}".format(datetime.now().strftime("%h"),
|
281 |
+
datetime.now().strftime("%d"),
|
282 |
+
datetime.now().strftime("%H"),
|
283 |
+
datetime.now().strftime("%M"),
|
284 |
+
datetime.now().strftime("%S"))
|
285 |
+
if args.log:
|
286 |
+
writer = SummaryWriter(os.path.join(tensorboard_dir,suffix))
|
287 |
+
for epoch in range(args.nepoch):
|
288 |
+
if args.model in ['GFusion']: train_loss,y_hat, y_true=train(GFusion_model)
|
289 |
+
if args.log:
|
290 |
+
writer.add_scalar('loss/Train', train_loss, epoch)
|
291 |
+
if args.dataset in ['south',"north",'syn','flu']:
|
292 |
+
train_mae=mean_absolute_error(y_true, y_hat)
|
293 |
+
train_rmse = np.sqrt(mean_squared_error(y_true, y_hat))
|
294 |
+
if args.log:
|
295 |
+
writer.add_scalar('mae/Train', train_mae, epoch)
|
296 |
+
writer.add_scalar('rmse/Train', train_rmse, epoch)
|
297 |
+
print(( f"epoch[{epoch:d}] train_loss : {train_loss:.3f} train_mae : {train_mae:.3f} train_rmse : {train_rmse:.3f}" ))
|
298 |
+
if args.model in ['GFusion']:
|
299 |
+
if args.fidelity_train==True:
|
300 |
+
print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}")
|
301 |
+
print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}")
|
302 |
+
if epoch % args.test_per_round == 0:
|
303 |
+
if args.model in ['GFusion']:
|
304 |
+
val_loss, yhat_val, ytrue_val = test(val_Loader,GFusion_model,fidelity_weight)
|
305 |
+
test_loss, yhat_test, ytrue_test = test(test_Loader,GFusion_model,fidelity_weight)
|
306 |
+
if args.log:
|
307 |
+
writer.add_scalar('loss/val', val_loss, epoch)
|
308 |
+
writer.add_scalar('loss/test', test_loss, epoch)
|
309 |
+
if args.dataset in ['south',"north",'syn','flu']:
|
310 |
+
val_mae=mean_absolute_error(ytrue_val, yhat_val)
|
311 |
+
val_rmse = np.sqrt(mean_squared_error(ytrue_val, yhat_val))
|
312 |
+
if args.log:
|
313 |
+
writer.add_scalar('mae/val', val_mae, epoch)
|
314 |
+
writer.add_scalar('rmse/val', val_rmse, epoch)
|
315 |
+
print(blue( f"epoch[{epoch:d}] val_mae : {val_mae:.3f} val_rmse : {val_rmse:.3f}" ))
|
316 |
+
test_mae = mean_absolute_error(ytrue_test, yhat_test)
|
317 |
+
test_rmse = np.sqrt(mean_squared_error(ytrue_test, yhat_test))
|
318 |
+
test_var=explained_variance_score(ytrue_test,yhat_test)
|
319 |
+
test_coefOfDetermination=r2_score(ytrue_test,yhat_test)
|
320 |
+
test_Pearsonr=measure_Pearsonr(np.array(yhat_test).reshape(-1, 1),np.array(ytrue_test).reshape(-1))[0]
|
321 |
+
if args.log:
|
322 |
+
writer.add_scalar('mae/test', test_mae, epoch)
|
323 |
+
writer.add_scalar('rmse/test', test_rmse, epoch)
|
324 |
+
print(blue( f"epoch[{epoch:d}] test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_Pearsonr: {test_Pearsonr:.3f} test_coefOfDetermination: {test_coefOfDetermination:.3f}" ))
|
325 |
+
if args.model in ['GFusion']:
|
326 |
+
if args.fidelity_train==True:
|
327 |
+
print(f"fidelity weight: {fidelity_weight[0]:.3f}, {fidelity_weight[1]:.3f}")
|
328 |
+
print(f"weighted_fidelity_weight: {weighted_fidelity_weight[0]:.3f}, {weighted_fidelity_weight[1]:.3f}")
|
329 |
+
val_trigger=val_mae
|
330 |
+
if val_trigger < best_val_trigger:
|
331 |
+
best_val_trigger = val_trigger
|
332 |
+
best_GFusion_model = copy.deepcopy(GFusion_model)
|
333 |
+
best_fidelity=copy.deepcopy(fidelity_weight)
|
334 |
+
best_info=[epoch,val_trigger]
|
335 |
+
"""
|
336 |
+
update lr when epoch≥30
|
337 |
+
"""
|
338 |
+
if epoch >= 30:
|
339 |
+
lr = scheduler.optimizer.param_groups[0]['lr']
|
340 |
+
if old_lr!=lr:
|
341 |
+
print(red('lr'), epoch, (lr), sep=', ')
|
342 |
+
old_lr=lr
|
343 |
+
scheduler.step(val_trigger)
|
344 |
+
if args.fidelity_train:
|
345 |
+
scheduler2.step(val_trigger)
|
346 |
+
val_loss, yhat_val, ytrue_val = test(val_Loader,best_GFusion_model,best_fidelity)
|
347 |
+
test_loss, yhat_test, ytrue_test = test(test_Loader,best_GFusion_model,best_fidelity)
|
348 |
+
if args.dataset in ['south',"north",'syn','flu']:
|
349 |
+
val_mae = mean_absolute_error(ytrue_val, yhat_val)
|
350 |
+
val_rmse=np.sqrt(mean_squared_error(ytrue_val,yhat_val))
|
351 |
+
val_var=explained_variance_score(ytrue_val,yhat_val)
|
352 |
+
print(blue( f"best_val val_mae: {val_mae:.3f} val_rmse: {val_rmse:.3f} val_var: {val_var:.3f}" ))
|
353 |
+
|
354 |
+
test_mae=mean_absolute_error(ytrue_test,yhat_test)
|
355 |
+
test_rmse=np.sqrt(mean_squared_error(ytrue_test,yhat_test))
|
356 |
+
test_var=explained_variance_score(ytrue_test,yhat_test)
|
357 |
+
test_coefOfDetermination=r2_score(ytrue_test,yhat_test)
|
358 |
+
test_Pearsonr=measure_Pearsonr(np.array(yhat_test).reshape(-1, 1),np.array(ytrue_test).reshape(-1))[0]
|
359 |
+
print(blue( f"best_test test_mae: {test_mae:.3f} test_rmse: {test_rmse:.3f} test_var: {test_var:.3f}" ))
|
360 |
+
if not args.loadmodel:
|
361 |
+
"""
|
362 |
+
save training info and best result
|
363 |
+
"""
|
364 |
+
result_file=os.path.join(info_dir, suffix)
|
365 |
+
with open(result_file, 'w') as f:
|
366 |
+
print(args.num_neighbors,args.nepoch,sep=' ',file=f)
|
367 |
+
print(f"fidelity weight: {best_fidelity[0]:.3f}, {best_fidelity[1]:.3f}",file=f)
|
368 |
+
print("Random Seed: ", Seed,file=f)
|
369 |
+
if args.dataset in ['south',"north",'syn','flu']:
|
370 |
+
print(f"MAE val : {val_mae:.3f}, Test : {test_mae:.3f}", file=f)
|
371 |
+
print(f"rmse val : {val_rmse:.3f}, Test : {test_rmse:.3f}", file=f)
|
372 |
+
print(f"var val : {val_var:.3f}, Test : {test_var:.3f}", file=f)
|
373 |
+
print(f"test_coefOfDetermination: {test_coefOfDetermination:.3f}, test_Pearsonr : {test_Pearsonr:.3f}", file=f)
|
374 |
+
print(f"Best info: {best_info}", file=f)
|
375 |
+
for i in [[a,getattr(args, a)] for a in args.__dict__]:
|
376 |
+
print(i,sep='\n',file=f)
|
377 |
+
with open(os.path.join(model_dir,'best_f_weight'+"_"+suffix+".pkl"), 'wb') as handle:
|
378 |
+
pickle.dump(fidelity_weight, handle)
|
379 |
+
torch.save(best_GFusion_model.state_dict(), os.path.join(model_dir,'best_GFusion_model'+"_"+suffix+'.pth') )
|
380 |
+
print("done")
|
381 |
+
|
382 |
+
if __name__ == '__main__':
    # Script entry point: parse options, prepare output folders, seed the RNGs,
    # split the graphs into train/val/test loaders and hand over to main().
    args = get_args()

    # Output layout under save_dir: log/ (tensorboard), model/ (checkpoints), info/ (result files).
    tensorboard_dir = os.path.join(args.save_dir, 'log')
    model_dir = os.path.join(args.save_dir, 'model')
    info_dir = os.path.join(args.save_dir, 'info')
    for directory in (args.save_dir, tensorboard_dir, model_dir, info_dir):
        if not os.path.exists(directory):
            os.makedirs(directory, exist_ok=True)

    # Use the requested seed, or draw one so the run can still be reported.
    if args.manualSeed:
        Seed = args.man_seed
    else:
        Seed = random.randint(1, 10000)
    print("Random Seed: ", Seed)
    random.seed(Seed)
    torch.manual_seed(Seed)
    np.random.seed(Seed)
    flag = 0

    if args.dataset not in ['south', "north", 'syn', "flu"]:
        raise Exception('Dataset not recognized.')

    graphs1, graphs2 = dataset.load_point(args.dataset, args.num_neighbors, [False, 200, 500])
    np.random.shuffle(graphs1)

    # 20% of graphs1 each for validation and test; the remainder is for training.
    val_test_split = int(np.around( 2 / 10 * len(graphs1) ))
    train_val_split = int(len(graphs1) - 2 * val_test_split)
    val_end = train_val_split + val_test_split

    # graphs2 is added to the training set only, unless single_high is set.
    train_graphs = graphs1[:train_val_split]
    if not args.single_high:
        train_graphs = train_graphs + graphs2
    val_graphs = graphs1[train_val_split:val_end]
    test_graphs = graphs1[val_end:]

    np.random.shuffle(train_graphs)
    train_Loader = DataLoader(train_graphs, batch_size=args.batchSize)
    val_Loader = DataLoader(val_graphs, batch_size=args.batchSize)
    test_Loader = DataLoader(test_graphs, batch_size=args.batchSize)
    print(f"train_pair_num: {len(train_graphs)}, val_pair_num: {len(val_graphs)}, test_pair_num: {len(test_graphs)}")

    main(args, train_Loader, val_Loader, test_Loader)
|
421 |
+
|
utils/utils.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
import networkx as nx
|
4 |
+
from networkx.utils import UnionFind
|
5 |
+
|
6 |
+
from typing import Optional
|
7 |
+
import torch
|
8 |
+
from torch import Tensor
|
9 |
+
|
10 |
+
from torch_sparse import SparseTensor
|
11 |
+
from scipy.sparse import csr_matrix
|
12 |
+
from math import pi as PI
|
13 |
+
import torch.nn.functional as F
|
14 |
+
def unique(sequence):
    """Return the distinct items of *sequence* in first-seen order.

    Args:
        sequence: Any iterable of hashable items.

    Returns:
        A new list holding each distinct item once, ordered by first
        occurrence.
    """
    # dicts preserve insertion order, so the keys are exactly the
    # order-preserving de-duplication of the input.
    return list(dict.fromkeys(sequence))
|
17 |
+
def pos2key(pos):
    """Build a string key from the first two entries of a position.

    Args:
        pos: Array-like with at least two numeric entries once flattened
            (anything exposing ``reshape``, or a plain sequence).

    Returns:
        ``"<a>_<b>"`` where ``a`` and ``b`` are the first two flattened
        entries, each rendered zero-padded to width 8 with 4 decimals.
    """
    # Plain lists/tuples have no reshape(); objects that already provide it
    # are used untouched.
    if not hasattr(pos, "reshape"):
        pos = np.asarray(pos)
    flat = pos.reshape(-1)
    return f"{flat[0]:08.4f}_{flat[1]:08.4f}"
|
21 |
+
def get_angle(v1: Tensor, v2: Tensor) -> Tensor:
    """Return the unsigned angle between corresponding rows of *v1* and *v2*.

    Args:
        v1: Tensor whose dimension 1 holds the vector components (2 or 3).
        v2: Tensor of matching layout.

    Returns:
        ``atan2(|v1 x v2|, v1 . v2)`` per row. The first argument is a norm
        and therefore non-negative, so the angles lie in ``[0, pi]``.
    """
    # The cross product needs 3 components: append a zero z-coordinate to
    # 2-component inputs.
    if v1.shape[1] == 2:
        v1 = F.pad(v1, (0, 1))
    if v2.shape[1] == 2:
        v2 = F.pad(v2, (0, 1))

    cross_norm = torch.cross(v1, v2, dim=1).norm(p=2, dim=1)
    dot = (v1 * v2).sum(dim=1)
    return torch.atan2(cross_norm, dot)
|
28 |
+
class GaussianSmearing(torch.nn.Module):
    """Expand scalar values onto a bank of evenly spaced Gaussian bumps.

    Args:
        start: Centre of the first Gaussian.
        stop: Centre of the last Gaussian.
        num_gaussians: Number of centres, evenly spaced over
            ``[start, stop]``. Must be at least 2 because the Gaussian width
            is derived from the spacing between the first two centres.

    Raises:
        ValueError: If ``num_gaussians`` is smaller than 2.
    """

    def __init__(self, start=-PI, stop=PI, num_gaussians=12):
        super().__init__()
        if num_gaussians < 2:
            # Without this the spacing lookup below fails with an opaque
            # IndexError on offset[1].
            raise ValueError("num_gaussians must be at least 2")
        offset = torch.linspace(start, stop, num_gaussians)
        # exp(coeff * d^2) with coeff = -1 / (2 * spacing^2).
        self.coeff = -0.5 / (offset[1] - offset[0]).item() ** 2
        # Registered as a buffer so it follows the module across devices
        # without being a trainable parameter.
        self.register_buffer("offset", offset)

    def forward(self, dist):
        """Return a ``(dist.numel(), num_gaussians)`` tensor of responses."""
        # Flatten the input to a column and broadcast against every centre.
        dist = dist.view(-1, 1) - self.offset.view(1, -1)
        return torch.exp(self.coeff * torch.pow(dist, 2))
|
38 |
+
|
39 |
+
def triplets(edge_index, num_nodes):
    """Enumerate pairs of chained edges (triplets of nodes) in a graph.

    For every edge ``e = (row[e], col[e])`` the adjacency row of node
    ``col[e]`` is looked up, pairing ``e`` with every edge stored as
    ``(col[e], k)``. Degenerate combinations are then filtered out.

    Args:
        edge_index: Pair ``(row, col)`` of equally sized 1-D index tensors,
            one entry per edge.
        num_nodes: Number of nodes; sets the sparse adjacency size.

    Returns:
        A 4-tuple ``(idx, num_triplets_real, edx_1st, edx_2nd)``:

        * ``idx``: ``[3, T]`` tensor stacking ``idx_i, idx_j, idx_k`` for the
          ``T`` kept triplets.
        * ``num_triplets_real``: per edge, the *running* (cumulative) number
          of kept triplets over all edges up to and including that edge.
        * ``edx_1st``: edge id (position in ``edge_index``) of the first edge
          of each kept triplet.
        * ``edx_2nd``: edge id of the second edge of each kept triplet.
    """
    row, col = edge_index

    # Store each edge's position as the sparse value so edge ids can be
    # recovered after the row lookup.
    value = torch.arange(row.size(0), device=row.device)
    adj_t = SparseTensor(row=row, col=col, value=value,
                         sparse_sizes=(num_nodes, num_nodes))
    # One adjacency row per edge: the row belonging to that edge's `col` node.
    adj_t_row = adj_t[col]
    # Non-zeros per selected row = candidate second edges per first edge.
    num_triplets = adj_t_row.set_value(None).sum(dim=1).to(torch.long)

    idx_i = row.repeat_interleave(num_triplets)
    idx_j = col.repeat_interleave(num_triplets)
    edx_1st = value.repeat_interleave(num_triplets)
    idx_k = adj_t_row.storage.col()
    edx_2nd = adj_t_row.storage.value()

    # Drop triplets in which exactly two of the three nodes coincide.
    mask1 = (idx_i == idx_k) & (idx_j != idx_i)
    mask2 = (idx_i == idx_j) & (idx_j != idx_k)
    mask3 = (idx_j == idx_k) & (idx_i != idx_k)
    mask = ~(mask1 | mask2 | mask3)
    idx_i, idx_j, idx_k, edx_1st, edx_2nd = (
        idx_i[mask], idx_j[mask], idx_k[mask], edx_1st[mask], edx_2nd[mask])

    # Cumulative kept-triplet count per edge = cumulative candidates minus
    # cumulative removals. The removal prefix sum gets a leading zero so a
    # cumulative candidate count of 0 maps to "0 removed"; indexing the
    # unpadded prefix sum at (count - 1) would wrap around to its last entry
    # and would fail outright when there are no candidates at all.
    cum_triplets = torch.cumsum(num_triplets, dim=0)
    removed = torch.cumsum(~mask, dim=0)
    removed = torch.cat([removed.new_zeros(1), removed])
    num_triplets_real = cum_triplets - removed[cum_triplets]

    return torch.stack([idx_i, idx_j, idx_k]), num_triplets_real.to(torch.long), edx_1st, edx_2nd
|
62 |
+
|
63 |
+
|
64 |
+
if __name__ == '__main__':
    # Helper module: nothing to do when executed directly.
    pass
|
66 |
+
|
67 |
+
|