# temporal-twins-code/src/tgn/evaluate.py
# Add anonymous Temporal Twins code release (commit a3682cf)
import torch
from sklearn.metrics import roc_auc_score, average_precision_score
from src.tgn.time_encoding import TimeEncoding
from src.tgn.memory import Memory
def evaluate(model, memory, graph_data, norm_stats):
    """Score the test edges of ``graph_data`` and return ROC-AUC / PR-AUC.

    A fresh node memory is built by replaying the training edges through
    ``model.compute_message`` in batches, then the test edges are scored
    with ``model.predict`` using that memory.

    Args:
        model: Object exposing ``compute_message(h_u, h_v, edge_feat, time_enc)``
            and ``predict(h_u, h_v, edge_feat, x_u, x_v, time_enc)``. If it is
            a ``torch.nn.Module`` it is put in eval mode for the duration of
            the call and its previous train/eval mode is restored afterwards.
        memory: Ignored. Kept for interface compatibility; a new ``Memory`` is
            always rebuilt from the training edges.
        graph_data: Mapping with keys ``"edge_index"``, ``"edge_attr"``,
            ``"y"``, ``"x"``, ``"train_idx"`` and ``"test_idx"``. Column 1 of
            ``"edge_attr"`` is read as the edge timestamp.
        norm_stats: Mapping with keys ``"ea_mean"``, ``"ea_std"``, ``"t_min"``
            and ``"t_max"`` used to normalize edge features and timestamps.

    Returns:
        Tuple ``(roc, pr, probs, y_true)``: ROC-AUC, average precision, the
        sigmoid probabilities for the test edges, and their labels (the last
        two as NumPy arrays).
    """
    device = torch.device("cpu")
    batch_size = 1024

    edge_index = torch.tensor(graph_data["edge_index"], dtype=torch.long)
    # Convert the raw edge attributes once; both the normalized features and
    # the timestamps are derived from this single tensor.
    raw_edge_attr = torch.tensor(graph_data["edge_attr"], dtype=torch.float32)
    labels = torch.tensor(graph_data["y"], dtype=torch.float32)

    # NOTE(review): node features are standardized with statistics computed
    # from graph_data["x"] itself, not from norm_stats -- confirm this matches
    # what training does.
    x = torch.tensor(graph_data["x"], dtype=torch.float32).to(device)
    x = (x - x.mean(dim=0)) / (x.std(dim=0) + 1e-6)

    # Apply the same normalization as training (statistics come from norm_stats).
    edge_attr = (raw_edge_attr - norm_stats["ea_mean"]) / norm_stats["ea_std"]
    t_min = norm_stats["t_min"]
    t_max = norm_stats["t_max"]
    # Timestamps are taken from the *un-normalized* edge attributes.
    timestamps = (raw_edge_attr[:, 1] - t_min) / (t_max - t_min + 1e-6)

    test_idx = graph_data["test_idx"]
    train_idx = graph_data["train_idx"]

    # Rebuild memory from train edges only. This deliberately replaces the
    # `memory` argument.
    # NOTE(review): memory_dim=64 and TimeEncoding(16) are hard-coded and
    # presumably must match the sizes the model was trained with -- confirm.
    memory = Memory(x.shape[0], memory_dim=64, device=device)
    # NOTE(review): this time encoder is newly constructed here; if
    # TimeEncoding has learnable or randomly initialized parameters, they are
    # not the trained ones -- confirm against the training code.
    time_encoder = TimeEncoding(16).to(device)

    # Evaluate in inference mode so layers such as dropout behave
    # deterministically, then restore whatever mode the caller had set.
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    try:
        with torch.no_grad():
            # Replay the training edges batch by batch to populate the memory.
            for start in range(0, len(train_idx), batch_size):
                batch_ids = train_idx[start:start + batch_size]
                src = edge_index[0, batch_ids]
                dst = edge_index[1, batch_ids]

                msg = model.compute_message(
                    memory.get(src).detach(),
                    memory.get(dst).detach(),
                    edge_attr[batch_ids],
                    time_encoder(timestamps[batch_ids]),
                )

                # Each edge's message is delivered to both of its endpoints;
                # a node that appears several times in the batch receives the
                # mean of its messages.
                node_ids = torch.cat([src, dst])
                messages = torch.cat([msg, msg])
                unique_nodes, inverse_idx = torch.unique(
                    node_ids, return_inverse=True
                )
                agg_msg = torch.zeros_like(memory.memory[unique_nodes])
                agg_msg.index_add_(0, inverse_idx, messages)
                counts = torch.bincount(inverse_idx).unsqueeze(1)
                memory.update(unique_nodes, agg_msg / counts)

            # Evaluate on the test set using the memory built above.
            u = edge_index[0, test_idx].to(device)
            v = edge_index[1, test_idx].to(device)
            edge_feat = edge_attr[test_idx].to(device)
            time_enc = time_encoder(timestamps[test_idx].to(device))
            logits = model.predict(
                memory.get(u), memory.get(v), edge_feat, x[u], x[v], time_enc
            )
    finally:
        if is_module:
            model.train(was_training)

    probs = torch.sigmoid(logits).cpu().numpy()
    y_true = labels[test_idx].cpu().numpy()

    roc = roc_auc_score(y_true, probs)
    pr = average_precision_score(y_true, probs)
    return roc, pr, probs, y_true