clique / GraphUNets /src /utils /data_loader.py
qingy2024's picture
Upload folder using huggingface_hub
bf620c6 verified
import torch
from tqdm import tqdm
import networkx as nx
import numpy as np
import torch.nn.functional as F
from sklearn.model_selection import StratifiedKFold
from functools import partial
class G_data(object):
    """Hold a list of graphs together with stratified train/test folds.

    On construction the graphs are partitioned into 10 stratified folds
    (see ``sep_data``); ``use_fold_data`` then selects one fold as the
    active train/test split.
    """

    def __init__(self, num_class, feat_dim, g_list):
        """Store the dataset description and precompute the fold indices.

        Args:
            num_class: Value stored as ``self.num_class``.
            feat_dim: Value stored as ``self.feat_dim``.
            g_list: Sequence of graph objects; each must expose a
                ``label`` attribute, which is used for stratification.
        """
        self.num_class = num_class
        self.feat_dim = feat_dim
        self.g_list = g_list
        self.sep_data()

    def sep_data(self, seed=0):
        """Compute shuffled, stratified 10-fold index pairs.

        The result is stored in ``self.idx_list`` as a list of
        ``(train_indices, test_indices)`` pairs.

        Args:
            seed: ``random_state`` handed to ``StratifiedKFold``.
        """
        splitter = StratifiedKFold(n_splits=10, shuffle=True, random_state=seed)
        targets = [graph.label for graph in self.g_list]
        # Only the labels matter for stratification; the feature argument
        # is a placeholder of matching length.
        placeholder = np.zeros(len(targets))
        self.idx_list = [*splitter.split(placeholder, targets)]

    def use_fold_data(self, fold_idx):
        """Activate fold ``fold_idx`` (0-based) as the current split.

        Sets ``self.fold_idx`` to the 1-based fold number and fills
        ``self.train_gs`` / ``self.test_gs`` with the matching graphs.
        """
        self.fold_idx = fold_idx + 1
        train_positions, test_positions = self.idx_list[fold_idx]
        graphs = self.g_list
        self.train_gs = [graphs[pos] for pos in train_positions]
        self.test_gs = [graphs[pos] for pos in test_positions]
class FileLoader(object):
    """Read a graph-classification dataset from ``data/<name>/<name>.txt``.

    File layout, as parsed by this class:

    * first line: the number of graphs;
    * per graph: a header line ``<num_nodes> <label>`` followed by one
      line per node, ``<tag> <num_neighbors> <neighbor ids...>``.
      Anything after the declared neighbor ids on a node line is ignored.
    """

    def __init__(self, args):
        """Keep ``args``; ``load_data`` reads ``args.data`` and
        ``args.deg_as_tag`` from it."""
        self.args = args

    def line_genor(self, lines):
        """Yield the items of ``lines`` one at a time."""
        yield from lines

    def gen_graph(self, f, i, label_dict, feat_dict, deg_as_tag):
        """Parse the next graph from the line iterator ``f``.

        Args:
            f: Iterator over the remaining text lines of the dataset.
            i: Index of the graph being read; used only in error messages.
            label_dict: Mapping raw label -> dense index. Updated in place
                when a new label is seen.
            feat_dict: Mapping raw node tag -> dense index. Updated in
                place when a new tag is seen (including tags of nodes that
                are later dropped as isolated).
            deg_as_tag: If true, ``node_tags`` holds node degrees instead
                of the tags read from the file.

        Returns:
            An ``nx.Graph`` without isolated nodes, carrying ``label`` (the
            raw label from the file) and ``node_tags`` (one entry per
            remaining node, in graph node order).

        Raises:
            ValueError: If ``deg_as_tag`` is false and an edge references a
                node id that has no row of its own, so no tag is known
                for it.
        """
        n, label = [int(w) for w in next(f).strip().split()]
        if label not in label_dict:
            label_dict[label] = len(label_dict)

        g = nx.Graph()
        g.add_nodes_from(range(n))

        # Tags are keyed by node id so they stay aligned with the nodes
        # that survive isolate removal below.
        tag_of = {}
        for j in range(n):
            fields = next(f).strip().split()
            # Keep the tag, the neighbor count and exactly that many ids.
            used = int(fields[1]) + 2
            row = [int(w) for w in fields[:used]]
            raw_tag = row[0]
            if raw_tag not in feat_dict:
                feat_dict[raw_tag] = len(feat_dict)
            tag_of[j] = feat_dict[raw_tag]
            for neighbor in row[2:]:
                if neighbor != j:  # self-loops are skipped
                    g.add_edge(j, neighbor)

        g.label = label
        g.remove_nodes_from(list(nx.isolates(g)))

        if deg_as_tag:
            g.node_tags = [degree for _, degree in g.degree]
        else:
            try:
                g.node_tags = [tag_of[node] for node in g.nodes]
            except KeyError as err:
                raise ValueError(
                    'graph %d: node %r is used as a neighbor but has no '
                    'row of its own' % (i, err.args[0])
                ) from None
        return g

    def process_g(self, label_dict, tag2index, tagset, g):
        """Attach tensor features to ``g`` and return it.

        Mutates ``g`` in place:

        * ``g.label`` is replaced by ``label_dict[g.label]``, so this must
          be called only once per graph;
        * ``g.feas`` becomes the one-hot encoding of the node tags with
          ``len(tagset)`` classes;
        * ``g.A`` becomes the float adjacency matrix plus the identity.
        """
        g.label = label_dict[g.label]
        # Explicit integer dtype: an empty tag list would otherwise yield a
        # float tensor, which one_hot rejects.
        tag_ids = torch.tensor(
            [tag2index[tag] for tag in g.node_tags], dtype=torch.long
        )
        g.feas = F.one_hot(tag_ids, len(tagset))
        # to_numpy_matrix was removed in NetworkX 3.0; to_numpy_array
        # returns the same values as a plain ndarray.
        adjacency = torch.as_tensor(nx.to_numpy_array(g), dtype=torch.float32)
        g.A = adjacency + torch.eye(g.number_of_nodes())
        return g

    def load_data(self):
        """Load the dataset named by ``self.args.data``.

        Reads ``data/<name>/<name>.txt``, builds one graph per record,
        indexes the node tags seen across all graphs, and converts every
        graph with ``process_g``.

        Returns:
            A ``G_data`` built from the number of distinct labels, the
            number of distinct node tags, and the processed graphs.
        """
        args = self.args
        print('loading data ...')
        label_dict = {}
        feat_dict = {}

        with open('data/%s/%s.txt' % (args.data, args.data), 'r') as f:
            lines = f.readlines()

        line_iter = self.line_genor(lines)
        n_g = int(next(line_iter).strip())
        g_list = [
            self.gen_graph(line_iter, i, label_dict, feat_dict, args.deg_as_tag)
            for i in tqdm(range(n_g), desc="Create graph", unit='graphs')
        ]

        # Built exactly as before so the iteration order of the set, and
        # therefore the one-hot column order, is left untouched.
        tagset = set([])
        for g in g_list:
            tagset = tagset.union(set(g.node_tags))
        tagset = list(tagset)
        tag2index = {tag: idx for idx, tag in enumerate(tagset)}

        new_g_list = [
            self.process_g(label_dict, tag2index, tagset, g)
            for g in tqdm(g_list, desc="Process graph", unit='graphs')
        ]

        num_class = len(label_dict)
        feat_dim = len(tagset)
        print('# classes: %d' % num_class, '# maximum node tag: %d' % feat_dim)
        return G_data(num_class, feat_dim, new_g_list)