from __future__ import print_function import torch.utils.data as data import os import os.path import torch import numpy as np import pandas as pd import sys from torch_geometric.nn import knn_graph from torch_geometric.data import Data from torch_geometric.loader import DataLoader from torch_geometric.utils import add_self_loops from torch_geometric.data.collate import collate from torch_geometric.data.separate import separate import pickle import time from torch_geometric.data.data import BaseData from torch_geometric.data.storage import BaseStorage from typing import Any def mycollate(data_list): r"""Collates a Python list of :obj:`torch_geometric.data.Data` objects to the internal storage format of :class:`~torch_geometric.data.InMemoryDataset`.""" if len(data_list) == 1: return data_list[0], None data, slices, _ = collate( data_list[0].__class__, data_list=data_list, increment=False, add_batch=False, ) return data, slices def myseparate(cls, batch: BaseData, idx: int, slice_dict: Any) -> BaseData: data = cls().stores_as(batch) # We iterate over each storage object and recursively separate all its attributes: for batch_store, data_store in zip(batch.stores, data.stores): attrs = set(batch_store.keys()) for attr in attrs: slices = slice_dict[attr] data_store[attr] = _separate(attr, batch_store[attr], idx, slices, batch, batch_store) return data def _separate( key: str, value: Any, idx: int, slices: Any, batch: BaseData, store: BaseStorage, ) : # Narrow a `torch.Tensor` based on `slices`. key = str(key) cat_dim = batch.__cat_dim__(key, value, store) start, end = int(slices[idx]), int(slices[idx + 1]) value = value.narrow(cat_dim or 0, start, end - start) return value def load_point(datasetname="south",k=5,small=[False,50,100]): """ load point and build graph pairs """ print("loading") time1=time.time() if small[0]: print("small south dataset k=5") datasetname="south" k=5 filename=os.path.join("data",datasetname,datasetname+f'_{k}.pt') [data_graphs1,slices_graphs1,data_graphs2,slices_graphs2]=torch.load(filename) flattened_list_graphs1 = [myseparate(cls=data_graphs1.__class__, batch=data_graphs1,idx=i,slice_dict=slices_graphs1) for i in range(small[1]*2)] flattened_list_graphs2 = [myseparate(cls=data_graphs2.__class__, batch=data_graphs2,idx=i,slice_dict=slices_graphs2) for i in range(small[2]*2)] unflattened_list_graphs1= [flattened_list_graphs1[n:n+2] for n in range(0, len(flattened_list_graphs1), 2)] unflattened_list_graphs2= [flattened_list_graphs2[n:n+2] for n in range(0, len(flattened_list_graphs2), 2)] print(f"Load data used {time.time()-time1:.1f} seconds") return unflattened_list_graphs1,unflattened_list_graphs2 return process(datasetname,k) def process(datasetname="south",k=5): time1=time.time() """ build graph pairs """ point_path= os.path.join("data",datasetname,datasetname+".pkl") with open(point_path, 'rb') as f: data = pickle.load(f) graphs1=[] graphs2=[] for day in data: day_d1=day[0] day_d2=day[1] assert(len(day_d1)