Update data/processor.py
Browse files- data/processor.py +133 -8
data/processor.py
CHANGED
|
@@ -1,30 +1,155 @@
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
from torch_geometric.data import Data
|
|
|
|
|
|
|
| 4 |
|
| 5 |
class GraphProcessor:
|
| 6 |
-
"""
|
| 7 |
|
| 8 |
@staticmethod
|
| 9 |
-
def normalize_features(x):
|
| 10 |
"""Normalize node features"""
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
@staticmethod
|
| 14 |
def add_self_loops(edge_index, num_nodes):
|
| 15 |
"""Add self loops to graph"""
|
| 16 |
-
self_loops = torch.arange(num_nodes).unsqueeze(0).repeat(2, 1)
|
| 17 |
edge_index = torch.cat([edge_index, self_loops], dim=1)
|
| 18 |
return edge_index
|
| 19 |
|
| 20 |
@staticmethod
|
| 21 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 22 |
"""Move data to device safely"""
|
| 23 |
if hasattr(data, 'to'):
|
| 24 |
return data.to(device)
|
| 25 |
elif isinstance(data, (list, tuple)):
|
| 26 |
-
return [GraphProcessor.
|
| 27 |
elif isinstance(data, dict):
|
| 28 |
-
return {k: GraphProcessor.
|
| 29 |
else:
|
| 30 |
-
return data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch.nn.functional as F
|
| 3 |
from torch_geometric.data import Data
|
| 4 |
+
from torch_geometric.transforms import Compose
|
| 5 |
+
import numpy as np
|
| 6 |
|
| 7 |
class GraphProcessor:
    """Advanced data preprocessing utilities for graph data.

    Every method is a ``@staticmethod``; the class only serves as a
    namespace.  Methods that take ``data`` read and assign the attributes
    ``edge_index``, ``x``, ``y``, ``edge_attr`` and ``num_nodes`` on it
    (presumably a ``torch_geometric.data.Data`` object -- any object that
    exposes those attributes works).
    """

    @staticmethod
    def normalize_features(x, method='l2'):
        """Normalize a node-feature matrix.

        Args:
            x: Feature tensor; rows are treated as nodes (dim 0) and
                columns as features (dim 1).
            method: ``'l2'`` scales every row to unit L2 norm, ``'minmax'``
                rescales every column to ``[0, 1]``, ``'standard'``
                z-scores every column.  Any other value returns ``x``
                unchanged.

        Returns:
            The normalized tensor, or ``x`` itself when ``method`` is not
            recognised or ``x`` has no elements.
        """
        if x.numel() == 0:
            # Nothing to normalize; min()/max() over an empty dim would raise.
            return x
        if method == 'l2':
            return F.normalize(x, p=2, dim=1)
        if method == 'minmax':
            x_min = x.min(dim=0, keepdim=True)[0]
            x_max = x.max(dim=0, keepdim=True)[0]
            # The epsilon keeps constant columns from dividing by zero.
            return (x - x_min) / (x_max - x_min + 1e-8)
        if method == 'standard':
            mean = x.mean(dim=0, keepdim=True)
            if x.size(0) > 1:
                std = x.std(dim=0, keepdim=True)
            else:
                # The sample std of a single row is NaN; treat it as zero
                # spread so the result is 0 rather than NaN.
                std = torch.zeros_like(mean)
            return (x - mean) / (std + 1e-8)
        return x

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Append one ``(i, i)`` edge for every node ``0..num_nodes-1``.

        Existing self loops are not detected, so calling this on a graph
        that already has them produces duplicates.

        Args:
            edge_index: Tensor of shape ``(2, num_edges)``.
            num_nodes: Number of nodes in the graph.

        Returns:
            A new ``(2, num_edges + num_nodes)`` tensor on the same device
            and with the same dtype as ``edge_index``.
        """
        node_ids = torch.arange(
            num_nodes, dtype=edge_index.dtype, device=edge_index.device
        )
        self_loops = torch.stack([node_ids, node_ids], dim=0)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def remove_self_loops(edge_index):
        """Return ``edge_index`` without the columns whose endpoints match."""
        mask = edge_index[0] != edge_index[1]
        return edge_index[:, mask]

    @staticmethod
    def add_positional_features(data, encoding_dim=8):
        """Append positional encodings to ``data.x`` (in place).

        With edges present, column ``k`` of the encoding is the diagonal of
        the ``(k + 1)``-th power of the symmetrically normalized adjacency
        ``D^-1/2 A D^-1/2``.  That diagonal equals the diagonal of the
        random-walk matrix power ``(D^-1 A)^(k+1)``, i.e. the return
        probability of a ``(k + 1)``-step walk.  The graph is treated as
        undirected and unweighted.

        Without edges, the encoding is a one-hot of the node index for the
        first ``encoding_dim`` nodes and zero for the rest.

        The adjacency is materialized densely, so memory grows with
        ``num_nodes ** 2``.

        Args:
            data: Graph object; ``num_nodes``, ``edge_index`` and ``x`` are
                read, ``x`` is reassigned.
            encoding_dim: Number of encoding columns to add.

        Returns:
            The same ``data`` object, with ``x`` extended by (or set to)
            the encoding.
        """
        num_nodes = data.num_nodes
        edge_index = getattr(data, 'edge_index', None)
        x = getattr(data, 'x', None)

        if edge_index is not None and edge_index.shape[1] > 0:
            # Build everything on the graph's device so GPU graphs work.
            device = edge_index.device
            adj = torch.zeros(num_nodes, num_nodes, device=device)
            adj[edge_index[0], edge_index[1]] = 1
            # Symmetrize while keeping the matrix binary: an edge stored in
            # both directions must not count twice as much as an edge
            # stored in only one.
            adj = (adj + adj.t()).clamp(max=1)

            degree = adj.sum(dim=1)
            degree[degree == 0] = 1  # isolated nodes: avoid division by zero
            d_inv_sqrt = 1.0 / torch.sqrt(degree)
            # Row/column scaling; same result as D^-1/2 @ A @ D^-1/2 without
            # building two dense diagonal matrices.
            a_norm = d_inv_sqrt.unsqueeze(1) * adj * d_inv_sqrt.unsqueeze(0)

            rw_features = []
            a_power = torch.eye(num_nodes, device=device)
            for _ in range(encoding_dim):
                a_power = a_power @ a_norm
                rw_features.append(a_power.diag().unsqueeze(1))

            if rw_features:
                pos_encoding = torch.cat(rw_features, dim=1)
            else:
                # encoding_dim == 0: torch.cat([]) would raise.
                pos_encoding = torch.zeros(num_nodes, 0, device=device)
        else:
            # No edges: fall back to a one-hot of the node index.
            if x is not None:
                device = x.device
            elif edge_index is not None:
                device = edge_index.device
            else:
                device = None
            pos_encoding = torch.eye(num_nodes, encoding_dim, device=device)

        if x is not None:
            data.x = torch.cat([x, pos_encoding.to(x.device)], dim=1)
        else:
            data.x = pos_encoding

        return data

    @staticmethod
    def _filter_edge_attr(data, edge_mask, num_edges):
        """Keep ``data.edge_attr`` aligned with a filtered ``edge_index``."""
        edge_attr = getattr(data, 'edge_attr', None)
        if (
            torch.is_tensor(edge_attr)
            and edge_attr.dim() > 0
            and edge_attr.size(0) == num_edges
        ):
            data.edge_attr = edge_attr[edge_mask.to(edge_attr.device)]

    @staticmethod
    def augment_graph(data, aug_type='edge_drop', aug_ratio=0.1, inplace=True):
        """Apply one random augmentation to a graph.

        Args:
            data: Graph object to augment.
            aug_type: One of
                ``'edge_drop'`` -- drop each edge with probability
                ``aug_ratio``;
                ``'node_drop'`` -- drop each node with probability
                ``aug_ratio``, remove its incident edges and re-index the
                remaining nodes;
                ``'feature_noise'`` -- add Gaussian noise with standard
                deviation ``aug_ratio`` to ``x``;
                ``'feature_mask'`` -- zero each entry of ``x`` with
                probability ``aug_ratio``.
                Any other value leaves the graph untouched.
            aug_ratio: Drop probability or noise scale, see above.
            inplace: When ``True`` (default) the attributes of ``data`` are
                reassigned.  When ``False`` a shallow copy is augmented and
                the caller's object keeps its original tensors, which is
                what is needed when the same sample is reused across
                epochs.

        Returns:
            The augmented graph (``data`` itself when ``inplace`` is true).
        """
        if not inplace:
            import copy

            # A shallow copy is enough: attributes are only reassigned
            # below, no tensor is modified in place.
            data = copy.copy(data)

        if aug_type == 'edge_drop':
            edge_index = data.edge_index
            num_edges = edge_index.shape[1]
            keep_edges = (
                torch.rand(num_edges, device=edge_index.device) > aug_ratio
            )
            data.edge_index = edge_index[:, keep_edges]
            GraphProcessor._filter_edge_attr(data, keep_edges, num_edges)

        elif aug_type == 'node_drop':
            edge_index = data.edge_index
            device = edge_index.device
            num_nodes = data.num_nodes
            num_edges = edge_index.shape[1]

            keep_mask = torch.rand(num_nodes, device=device) > aug_ratio
            keep_nodes = torch.where(keep_mask)[0]

            # Old node id -> new contiguous id (-1 for dropped nodes).
            node_map = torch.full(
                (num_nodes,), -1, dtype=torch.long, device=device
            )
            node_map[keep_nodes] = torch.arange(
                keep_nodes.numel(), device=device
            )

            # Keep only edges whose two endpoints both survive.
            edge_mask = keep_mask[edge_index[0]] & keep_mask[edge_index[1]]
            data.edge_index = node_map[edge_index[:, edge_mask]]
            GraphProcessor._filter_edge_attr(data, edge_mask, num_edges)

            x = getattr(data, 'x', None)
            if x is not None:
                data.x = x[keep_nodes.to(x.device)]

            # Only node-level labels are filtered; a missing label or a
            # graph-level (0-dim) label is left alone.
            y = getattr(data, 'y', None)
            if torch.is_tensor(y) and y.dim() > 0 and y.size(0) == num_nodes:
                data.y = y[keep_nodes.to(y.device)]
            # NOTE(review): other per-node attributes and an explicitly
            # stored num_nodes are not updated here -- confirm whether any
            # caller relies on them after node_drop.

        elif aug_type == 'feature_noise':
            if data.x is not None:
                noise = torch.randn_like(data.x) * aug_ratio
                data.x = data.x + noise

        elif aug_type == 'feature_mask':
            if data.x is not None:
                mask = torch.rand_like(data.x) > aug_ratio
                data.x = data.x * mask

        return data

    @staticmethod
    def to_device_safe(data, device):
        """Recursively move tensors inside nested containers to ``device``.

        Anything with a ``.to`` method is moved via ``.to(device)``; lists,
        tuples (including namedtuples) and dicts are rebuilt with their
        elements moved; every other value is returned unchanged.

        Args:
            data: Tensor-like object, container, or arbitrary value.
            device: Target passed to ``.to``.

        Returns:
            The moved object; lists stay lists, tuples stay tuples, dicts
            stay dicts.
        """
        if hasattr(data, 'to'):
            return data.to(device)
        if isinstance(data, (list, tuple)):
            moved = [GraphProcessor.to_device_safe(item, device) for item in data]
            if isinstance(data, list):
                return moved
            if hasattr(data, '_fields'):
                # namedtuple: its constructor takes positional fields.
                return type(data)(*moved)
            return tuple(moved)
        if isinstance(data, dict):
            return {
                k: GraphProcessor.to_device_safe(v, device)
                for k, v in data.items()
            }
        return data

    @staticmethod
    def validate_data(data):
        """Check a graph object for structural inconsistencies.

        Attributes that are absent or ``None`` are treated the same way.

        Args:
            data: Graph object; ``edge_index``, ``x`` and ``num_nodes`` are
                read if present.

        Returns:
            A list of human-readable error strings; empty when no problem
            was found.
        """
        errors = []
        edge_index = getattr(data, 'edge_index', None)
        num_nodes = getattr(data, 'num_nodes', None)
        x = getattr(data, 'x', None)

        # Basic structure.
        if edge_index is None:
            errors.append("Missing edge_index")
        elif edge_index.dim() != 2 or edge_index.shape[0] != 2:
            errors.append("edge_index must have shape (2, num_edges)")

        if x is not None and num_nodes is not None and x.shape[0] != num_nodes:
            errors.append("Feature matrix size mismatch")

        # Index range; skipped for an empty edge_index, where max()/min()
        # would raise.
        if (
            edge_index is not None
            and edge_index.numel() > 0
            and num_nodes is not None
        ):
            if edge_index.max().item() >= num_nodes:
                errors.append("Edge indices exceed number of nodes")
            if edge_index.min().item() < 0:
                errors.append("Edge indices must be non-negative")

        return errors
|