| | |
| | |
| | import torch |
| | import numpy as np |
| | import torch.nn.functional as Fd |
| | from deeprobust.graph.defense import GCNJaccard, GCN |
| | from deeprobust.graph.defense import GCNScore |
| | from deeprobust.graph.utils import * |
| | from deeprobust.graph.data import Dataset, PrePtbDataset |
| | from scipy.sparse import csr_matrix |
| | import argparse |
| | import pickle |
| | from deeprobust.graph import utils |
| | from collections import defaultdict |
| | from tqdm import tqdm |
| |
|
# ---------------------------------------------------------------------------
# Command-line configuration, device selection, seeding and data loading.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--dataset', type=str, default='pubmed', choices=['cora', 'cora_ml', 'citeseer', 'polblogs', 'pubmed'], help='dataset')
parser.add_argument('--ptb_rate', type=float, default=0.05, help='pertubation rate')

args = parser.parse_args()
args.cuda = torch.cuda.is_available()
print('cuda: %s' % args.cuda)

# Keep the historical preference for the second GPU, but fall back to the
# first one on single-GPU hosts, where "cuda:1" is not a valid device.
if args.cuda:
    device = torch.device("cuda:1" if torch.cuda.device_count() > 1 else "cuda:0")
else:
    device = torch.device("cpu")

# Seed every generator the script draws from. The sampling below uses
# torch.randperm, so seeding NumPy alone does not make runs reproducible.
np.random.seed(args.seed)
torch.manual_seed(args.seed)
if args.cuda:
    # Seed all CUDA devices: the selected device may not be the current one.
    torch.cuda.manual_seed_all(args.seed)

# Clean graph loaded with the 'prognn' setting, plus its index splits.
data = Dataset(root='/tmp/', name=args.dataset, setting='prognn')
adj, features, labels = data.adj, data.features, data.labels
idx_train, idx_val, idx_test = data.idx_train, data.idx_val, data.idx_test

# Pre-perturbed adjacency for the 'meta' attack method at the requested rate.
perturbed_data = PrePtbDataset(root='/tmp/',
                               name=args.dataset,
                               attack_method='meta',
                               ptb_rate=args.ptb_rate)

perturbed_adj = perturbed_data.adj
| | |
| |
|
def save_cg_scores(cg_scores, filename="cg_scores.npy"):
    """Persist ``cg_scores`` to ``filename`` via ``np.save`` and log the path.

    Args:
        cg_scores: Object handed to ``np.save`` unchanged.
        filename: Destination path; defaults to ``cg_scores.npy``.
    """
    np.save(file=filename, arr=cg_scores)
    # Report the path exactly as it was passed in.
    print(f"CG-scores saved to {filename}")
| |
|
def load_cg_scores_numpy(filename="cg_scores.npy"):
    """Read CG-scores back from ``filename`` with ``np.load`` and log the path.

    Args:
        filename: Path of the ``.npy`` file to read; defaults to ``cg_scores.npy``.

    Returns:
        Whatever ``np.load`` yields for that file.
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays,
    # which can execute arbitrary code -- only load files you produced yourself.
    loaded = np.load(file=filename, allow_pickle=True)
    print(f"CG-scores loaded from {filename}")
    return loaded
| |
|
| |
|
| | import torch |
| | import numpy as np |
| | from collections import defaultdict |
| | from tqdm import tqdm |
| |
|
| |
|
def _cg_gram_matrix(samples):
    """Build the regularised Gram matrix used by the CG-score.

    For row-normalised ``samples`` the entry for a pair with inner product
    ``t`` is ``t * (pi - arccos(t)) / (2 * pi)``. The diagonal is forced to
    ``0.5`` and a ``1e-6`` jitter is then added to it before inversion.

    Works for a single sample matrix ``(M, F)`` and for a batch ``(B, M, F)``,
    so the reference and the per-edge matrices get the same regularisation.
    """
    inner = torch.matmul(samples, samples.transpose(-1, -2))
    # Clamp so rounding noise cannot push arccos outside its domain.
    inner = torch.clamp(inner, min=-1.0, max=1.0)
    gram = inner * (np.pi - torch.acos(inner)) / (2 * np.pi)
    diag = gram.diagonal(dim1=-2, dim2=-1)  # writable view into ``gram``
    diag.fill_(0.5)
    diag.add_(1e-6)
    return gram


@torch.no_grad()
def calc_cg_score_gnn_with_sampling(
    A, X, labels, device, rep_num=1, unbalance_ratio=1, sub_term=False, batch_size=64
):
    """Compute per-edge CG-scores with negative sampling and edge batching.

    For every label group, the nodes of that label (targets ``+1``) and a
    random subset of the remaining nodes (targets ``-1``) form a sample set.
    With ``H`` the Gram matrix of the row-normalised ``A @ X`` rows of that
    set, the reference value is ``y^T H^{-1} y``. For each edge ``(i, j)``
    with ``j > i``, ``A[i, j] != 0`` and ``i`` in the current label group,
    the contribution of the edge is removed from rows ``i`` and ``j`` of
    ``A @ X``, the value is recomputed, and the score is
    ``reference - recomputed``. Scores are accumulated symmetrically and
    averaged over the number of times each edge was evaluated.

    Args:
        A: Square 2-D tensor indexed as ``A[i, j]``. The caller in this file
            passes a dense normalised adjacency; assumed dense here.
        X: 2-D tensor with one row per node (presumably node features).
        labels: 1-D tensor with one label per node.
        device: Torch device on which all tensor work is done.
        rep_num: Number of sampling repetitions to average over.
        unbalance_ratio: Negative samples drawn per node of the current
            label, capped by the number of available negatives.
        sub_term: If true, return the whole dictionary of arrays instead of
            only the ``"vi"`` array. The ``"ab"``, ``"a2"`` and ``"b2"``
            entries are never written by this function and stay zero.
        batch_size: Number of edges evaluated per batched inversion.

    Returns:
        The ``(N, N)`` NumPy array of averaged scores, or, when ``sub_term``
        is true, the dictionary holding it under ``"vi"`` together with the
        other arrays (``"times"`` is left un-averaged).
    """
    N = A.shape[0]
    cg_scores = {
        "vi": np.zeros((N, N)),
        "ab": np.zeros((N, N)),
        "a2": np.zeros((N, N)),
        "b2": np.zeros((N, N)),
        "times": np.zeros((N, N)),
    }

    A = A.to(device)
    X = X.to(device)
    labels = labels.to(device)

    # Everything below depends only on A, X and labels, so it is computed
    # once instead of once per repetition.
    AX = torch.matmul(A, X)
    norm_AX = AX / (torch.norm(AX, dim=1, keepdim=True) + 1e-8)

    unique_labels = torch.unique(labels).tolist()
    label_to_indices = {
        label: (labels == label).nonzero(as_tuple=True)[0] for label in unique_labels
    }
    neg_indices_dict = {
        label: (labels != label).nonzero(as_tuple=True)[0] for label in unique_labels
    }

    for _ in range(rep_num):
        for curr_label in tqdm(unique_labels, desc="Label groups"):
            curr_indices = label_to_indices[curr_label]
            curr_num = len(curr_indices)

            # Shuffle the positives, then draw the negatives.
            chosen_curr_indices = curr_indices[torch.randperm(curr_num, device=device)]

            neg_indices = neg_indices_dict[curr_label]
            neg_num = min(int(curr_num * unbalance_ratio), len(neg_indices))
            rand_idx = torch.randperm(len(neg_indices), device=device)[:neg_num]
            chosen_neg_indices = neg_indices[rand_idx]

            # Positives first, negatives second; y follows the same layout.
            sample_nodes = torch.cat([chosen_curr_indices, chosen_neg_indices], dim=0)
            combined_samples = norm_AX[sample_nodes]
            y = torch.cat(
                [
                    torch.ones(curr_num, dtype=combined_samples.dtype, device=device),
                    -torch.ones(neg_num, dtype=combined_samples.dtype, device=device),
                ],
                dim=0,
            )

            invH = torch.inverse(_cg_gram_matrix(combined_samples))
            original_error = y @ (invH @ y)

            # Node id -> row inside the sample matrix. Sampled ids are unique
            # (a permutation of the positives plus a duplicate-free subset of
            # the negatives), so the mapping is well defined.
            sample_list = sample_nodes.tolist()
            row_of = {node: row for row, node in enumerate(sample_list)}

            # Upper-triangular edges leaving the current label group. One
            # nonzero() per row replaces N scalar reads of A per node.
            edge_batch = []
            for idx_i in sample_list[:curr_num]:
                nbrs = torch.nonzero(A[idx_i, idx_i + 1:], as_tuple=True)[0] + (idx_i + 1)
                edge_batch.extend((idx_i, j) for j in nbrs.tolist())

            for k in tqdm(range(0, len(edge_batch), batch_size), desc="Edge batches", leave=False):
                batch = edge_batch[k : k + batch_size]
                B = len(batch)

                # Only sampled rows enter the Gram matrix, so copy just those
                # rows per edge instead of the whole (N, F) matrix.
                sample_batch = combined_samples.repeat(B, 1, 1)
                for b, (i, j) in enumerate(batch):
                    for node, other in ((i, j), (j, i)):
                        row = row_of.get(node)
                        if row is None:
                            # Endpoint is not in the sample set, so its row
                            # never reaches the Gram matrix.
                            continue
                        # Aggregated row of ``node`` without the edge to ``other``.
                        ax = AX[node] - A[node, other] * X[other]
                        sample_batch[b, row] = ax / (torch.norm(ax) + 1e-8)

                invH_batch = torch.inverse(_cg_gram_matrix(sample_batch))
                y_expanded = y.unsqueeze(0).expand(B, -1)
                error_A1 = torch.einsum("bi,bij,bj->b", y_expanded, invH_batch, y_expanded)

                # One device-to-host transfer per batch instead of one per edge.
                scores = (original_error - error_A1).tolist()
                for (i, j), score in zip(batch, scores):
                    # Accumulate both triangles so the average stays
                    # symmetric when an edge is evaluated more than once.
                    cg_scores["vi"][i, j] += score
                    cg_scores["vi"][j, i] += score
                    cg_scores["times"][i, j] += 1
                    cg_scores["times"][j, i] += 1

    # Average every accumulated array; entries never evaluated stay at zero.
    divisor = np.where(cg_scores["times"] > 0, cg_scores["times"], 1)
    for key in cg_scores:
        if key != "times":
            cg_scores[key] = cg_scores[key] / divisor

    return cg_scores if sub_term else cg_scores["vi"]
| |
|
| |
|
| |
|
def is_symmetric_sparse(adj):
    """Return whether ``adj`` equals its own transpose.

    The element-wise ``!=`` comparison against the transpose yields a matrix
    whose stored entries mark the mismatches; no stored entries means the
    matrix is symmetric.
    """
    mismatches = adj != adj.transpose()
    return mismatches.nnz == 0
| |
|
def make_symmetric_sparse(adj):
    """Symmetrise ``adj`` by averaging it with its transpose.

    Returns ``(adj + adj^T) / 2``. An entry present in only one direction
    therefore ends up with half its original weight in both directions.
    """
    transposed = adj.transpose()
    return (adj + transposed) / 2
| |
|
# ---------------------------------------------------------------------------
# Prepare tensors, normalise the perturbed adjacency, score and save.
# ---------------------------------------------------------------------------
perturbed_adj = make_symmetric_sparse(perturbed_adj)

if not isinstance(perturbed_adj, torch.Tensor):
    # NOTE(review): the arguments are passed as (features, adj, labels) and
    # unpacked in the same order. This relies on utils.to_tensor converting
    # its first two arguments independently -- confirm against deeprobust.
    features, perturbed_adj, labels = utils.to_tensor(features, perturbed_adj, labels)
else:
    features = features.to(device)
    perturbed_adj = perturbed_adj.to(device)
    labels = labels.to(device)

if utils.is_sparse_tensor(perturbed_adj):
    adj_norm = utils.normalize_adj_tensor(perturbed_adj, sparse=True)
else:
    adj_norm = utils.normalize_adj_tensor(perturbed_adj)

# The scoring routine indexes individual entries, so it needs dense tensors.
features = features.to_dense()
perturbed_adj = adj_norm.to_dense()

calc_cg_score = calc_cg_score_gnn_with_sampling(
    perturbed_adj,
    features,
    labels,
    device,
    rep_num=1,
    unbalance_ratio=3,
    sub_term=False,
    batch_size=512,
)
# Name the output after the actual run configuration; with the default
# arguments this is still "pubmed_0.05.npy".
save_cg_scores(calc_cg_score, filename=f"{args.dataset}_{args.ptb_rate}.npy")
| | |
| |
|