# temporal-twins-code / models / sequence_gru.py
# Anonymous "Temporal Twins" code release (commit a3682cf).
from __future__ import annotations
import copy
from typing import List
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from sklearn.metrics import average_precision_score, roc_auc_score
from models.base import TemporalModel
# Oracle / label-derived columns that must never be visible to the model as
# features; `_build_sequences` asserts none of these leak into the input frame.
_BLOCKED_COLS = frozenset({
    "motif_hit_count", "motif_source", "trigger_event_idx", "label_event_idx",
    "label_delay", "is_fallback_label", "fraud_source",
    "twin_role", "twin_label", "twin_pair_id", "template_id",
    "dynamic_fraud_state", "motif_chain_state", "motif_strength",
})
def _safe_roc_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float:
y_true = np.asarray(y_true, dtype=np.float32)
y_prob = np.asarray(y_prob, dtype=np.float32)
if len(y_true) == 0 or len(np.unique(y_true)) < 2:
return 0.5
return float(roc_auc_score(y_true, y_prob))
def _safe_pr_auc(y_true: np.ndarray, y_prob: np.ndarray) -> float:
y_true = np.asarray(y_true, dtype=np.float32)
y_prob = np.asarray(y_prob, dtype=np.float32)
positives = float(np.sum(y_true == 1))
negatives = float(np.sum(y_true == 0))
if positives == 0.0:
return 0.0
if negatives == 0.0:
return 1.0
return float(average_precision_score(y_true, y_prob))
class _SeqGRU(nn.Module):
def __init__(
self,
num_buckets: int,
numeric_dim: int,
emb_dim: int = 32,
pos_dim: int = 16,
time_dim: int = 24,
hidden_dim: int = 64,
max_positions: int = 256,
):
super().__init__()
self.receiver_emb = nn.Embedding(num_buckets + 1, emb_dim)
self.position_emb = nn.Embedding(max_positions + 1, pos_dim)
self.numeric_proj = nn.Sequential(
nn.Linear(numeric_dim, time_dim),
nn.ReLU(),
nn.LayerNorm(time_dim),
)
self.input_proj = nn.Sequential(
nn.Linear(emb_dim + pos_dim + time_dim, hidden_dim),
nn.ReLU(),
)
self.gru = nn.GRU(
input_size=hidden_dim,
hidden_size=hidden_dim,
batch_first=True,
bidirectional=False,
)
self.attn = nn.Sequential(
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
)
self.head = nn.Sequential(
nn.LayerNorm(hidden_dim * 3),
nn.Linear(hidden_dim * 3, hidden_dim),
nn.ReLU(),
nn.Dropout(0.10),
nn.Linear(hidden_dim, 1),
)
def forward(
self,
receiver_ids: torch.Tensor,
numeric_feats: torch.Tensor,
positions: torch.Tensor,
lengths: torch.Tensor,
) -> torch.Tensor:
emb = self.receiver_emb(receiver_ids)
pos_emb = self.position_emb(positions)
time_repr = self.numeric_proj(numeric_feats)
x = torch.cat([emb, pos_emb, time_repr], dim=-1)
x = self.input_proj(x)
h_seq, _ = self.gru(x)
batch_size, seq_len, hidden_dim = h_seq.shape
mask = (
torch.arange(seq_len, device=lengths.device).unsqueeze(0)
< lengths.unsqueeze(1)
)
masked_h = h_seq.masked_fill(~mask.unsqueeze(-1), -1e9)
attn_scores = self.attn(h_seq).squeeze(-1).masked_fill(~mask, -1e9)
attn_weights = torch.softmax(attn_scores, dim=1)
attn_pool = (h_seq * attn_weights.unsqueeze(-1)).sum(dim=1)
max_hidden = masked_h.max(dim=1).values
sum_hidden = (h_seq * mask.unsqueeze(-1)).sum(dim=1)
mean_hidden = sum_hidden / lengths.clamp(min=1).unsqueeze(1)
pooled = torch.cat([attn_pool, max_hidden, mean_hidden], dim=-1)
logits = self.head(pooled).squeeze(-1)
return logits
class SequenceGRUWrapper(TemporalModel):
    """Per-sender sequence classifier built on the ``_SeqGRU`` backbone.

    Events are grouped by ``sender_id``, ordered by ``timestamp``, and encoded
    as (receiver-token, numeric-feature, position) sequences.  Two training
    entry points are provided: matched prefix examples
    (``fit_matched_prefix_examples`` / ``predict_matched_prefix_examples``) and
    whole-history node classification (``train_node_classifier_on_prefix`` /
    ``predict``).  If training labels are degenerate (empty or single-class),
    the model falls back to emitting a constant probability.
    """

    def __init__(
        self,
        hidden_dim: int = 64,
        receiver_buckets: int = 256,
        max_positions: int = 256,
        device: str = "cpu",
    ):
        self.hidden_dim = hidden_dim
        self.receiver_buckets = receiver_buckets
        self.max_positions = max_positions
        self.device = torch.device(device)
        # Created lazily by fit(); None until then.
        self._model: _SeqGRU | None = None
        # When set, predictions short-circuit to this constant probability.
        self._constant_prob: float | None = None

    @property
    def name(self) -> str:
        """Human-readable model identifier."""
        return "SeqGRU"

    @property
    def is_temporal(self) -> bool:
        """This model consumes event order/time, so it is temporal."""
        return True

    def fit(self, df_train: pd.DataFrame, num_epochs: int = 3) -> None:
        """Instantiate a fresh network; actual training happens in the
        dedicated training methods.  ``df_train`` and ``num_epochs`` are
        accepted for interface compatibility but unused here.
        """
        self._model = _SeqGRU(
            num_buckets=self.receiver_buckets,
            # 6 numeric features, matching _build_event_numeric's output.
            numeric_dim=6,
            emb_dim=32,
            hidden_dim=self.hidden_dim,
            max_positions=self.max_positions,
        ).to(self.device)
        self._constant_prob = None

    def _receiver_token(self, receiver_ids: np.ndarray) -> np.ndarray:
        """Map receiver ids to first-occurrence ordinals within one sequence.

        Token 0 is reserved for padding; the first distinct receiver gets
        token 1, the second token 2, and so on.  Receivers beyond
        ``receiver_buckets`` all share the last bucket (min with the cap).
        """
        receiver_ids = np.asarray(receiver_ids, dtype=np.int64)
        local_map: dict[int, int] = {}
        next_token = 1
        tokens = np.zeros(len(receiver_ids), dtype=np.int64)
        for idx, receiver_id in enumerate(receiver_ids.tolist()):
            if receiver_id not in local_map:
                # Cap at receiver_buckets so the embedding index stays valid.
                local_map[receiver_id] = min(next_token, self.receiver_buckets)
                next_token += 1
            tokens[idx] = local_map[receiver_id]
        return tokens

    def _build_event_numeric(self, group: pd.DataFrame) -> np.ndarray:
        """Build the (len, 6) numeric feature matrix for one sender's events.

        Features: log1p inter-event gap, log1p amount, retry flag, failed
        flag, and sin/cos of the time-of-day phase (86400-second day).
        Missing optional columns (amount/is_retry/failed) default to zeros.
        """
        group = group.sort_values("timestamp").reset_index(drop=True)
        timestamps = group["timestamp"].to_numpy(dtype=np.float64)
        # prepend=timestamps[0] makes the first gap 0; negative gaps clipped.
        dts = np.diff(timestamps, prepend=timestamps[0])
        dts = np.maximum(dts, 0.0)
        # NOTE(review): assumes timestamps are in seconds — confirm upstream.
        phase = (timestamps % 86400.0) / 86400.0
        amount = group["amount"].to_numpy(dtype=np.float32) if "amount" in group.columns else np.zeros(len(group), dtype=np.float32)
        retry = group["is_retry"].to_numpy(dtype=np.float32) if "is_retry" in group.columns else np.zeros(len(group), dtype=np.float32)
        failed = group["failed"].to_numpy(dtype=np.float32) if "failed" in group.columns else np.zeros(len(group), dtype=np.float32)
        return np.stack(
            [
                np.log1p(dts).astype(np.float32),
                np.log1p(np.maximum(amount, 0.0)).astype(np.float32),
                retry.astype(np.float32),
                failed.astype(np.float32),
                np.sin(2.0 * np.pi * phase).astype(np.float32),
                np.cos(2.0 * np.pi * phase).astype(np.float32),
            ],
            axis=1,
        )

    def _finalize_sequence(
        self,
        receiver_ids: np.ndarray,
        numeric: np.ndarray,
        perm: np.ndarray | None = None,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Apply an optional permutation, tokenize receivers, and build
        positions clipped at ``max_positions``.

        Returns (receiver_tokens, numeric_features, positions).  Note the
        tokenization runs AFTER the permutation, so tokens reflect the
        shuffled first-occurrence order.
        """
        receiver_ids = np.asarray(receiver_ids, dtype=np.int64)
        numeric = np.asarray(numeric, dtype=np.float32)
        if perm is not None and len(receiver_ids):
            receiver_ids = receiver_ids[perm]
            numeric = numeric[perm]
        receiver_tokens = self._receiver_token(receiver_ids)
        positions = np.minimum(
            np.arange(len(receiver_tokens), dtype=np.int64),
            self.max_positions,
        )
        return receiver_tokens, numeric.astype(np.float32), positions

    def _pad_example_batch(
        self,
        receiver_seqs: list[np.ndarray],
        numeric_seqs: list[np.ndarray],
        position_seqs: list[np.ndarray],
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Right-pad variable-length sequences with zeros and move the batch
        to the configured device.

        Returns (receiver_ids, numeric_feats, positions, lengths) tensors.
        """
        lengths = np.array([len(seq) for seq in receiver_seqs], dtype=np.int64)
        # max_len is at least 1 even for an empty batch so shapes stay valid.
        max_len = int(max(lengths.max() if len(lengths) else 1, 1))
        recv_batch = np.zeros((len(receiver_seqs), max_len), dtype=np.int64)
        feat_batch = np.zeros((len(receiver_seqs), max_len, 6), dtype=np.float32)
        pos_batch = np.zeros((len(receiver_seqs), max_len), dtype=np.int64)
        for idx, (receiver_ids, numeric, positions) in enumerate(zip(receiver_seqs, numeric_seqs, position_seqs)):
            seq_len = len(receiver_ids)
            recv_batch[idx, :seq_len] = receiver_ids
            feat_batch[idx, :seq_len, :] = numeric
            pos_batch[idx, :seq_len] = positions
        return (
            torch.tensor(recv_batch, dtype=torch.long, device=self.device),
            torch.tensor(feat_batch, dtype=torch.float32, device=self.device),
            torch.tensor(pos_batch, dtype=torch.long, device=self.device),
            torch.tensor(lengths, dtype=torch.long, device=self.device),
        )

    def _build_sequences(self, df: pd.DataFrame, eval_nodes: List[int]):
        """Build one padded full-history sequence batch per eval node.

        Asserts none of the oracle/leakage columns are present.  Nodes with
        no events get a single all-zero (padding) step so every node yields
        a length-1 sequence.
        """
        leaked = _BLOCKED_COLS & set(df.columns)
        assert not leaked, f"Oracle columns leaked into SeqGRU: {leaked}"
        df = df.sort_values("timestamp").reset_index(drop=True).copy()
        groups = {int(sender_id): group for sender_id, group in df.groupby("sender_id", sort=False)}
        sequences = []
        lengths = []
        for node_id in eval_nodes:
            group = groups.get(int(node_id))
            if group is None or group.empty:
                # Placeholder: one padding event for nodes without history.
                receiver_ids = np.zeros((1,), dtype=np.int64)
                numeric = np.zeros((1, 6), dtype=np.float32)
            else:
                receiver_ids, numeric, _ = self._finalize_sequence(
                    group["receiver_id"].to_numpy(dtype=np.int64),
                    self._build_event_numeric(group),
                )
            sequences.append((receiver_ids, numeric))
            lengths.append(len(receiver_ids))
        max_len = max(lengths) if lengths else 1
        recv_batch = np.zeros((len(eval_nodes), max_len), dtype=np.int64)
        feat_batch = np.zeros((len(eval_nodes), max_len, 6), dtype=np.float32)
        pos_batch = np.zeros((len(eval_nodes), max_len), dtype=np.int64)
        for idx, (receiver_ids, numeric) in enumerate(sequences):
            seq_len = len(receiver_ids)
            recv_batch[idx, :seq_len] = receiver_ids
            feat_batch[idx, :seq_len, :] = numeric
            pos_batch[idx, :seq_len] = np.minimum(
                np.arange(seq_len, dtype=np.int64),
                self.max_positions,
            )
        return (
            torch.tensor(recv_batch, dtype=torch.long, device=self.device),
            torch.tensor(feat_batch, dtype=torch.float32, device=self.device),
            torch.tensor(pos_batch, dtype=torch.long, device=self.device),
            torch.tensor(lengths, dtype=torch.long, device=self.device),
        )

    def _build_matched_example_dataset(
        self,
        df: pd.DataFrame,
        examples: pd.DataFrame,
        shuffle_within_sequence: bool = False,
        seed: int = 0,
    ) -> dict:
        """Materialize per-example prefix sequences for matched examples.

        Each ``examples`` row selects a sender and a prefix cut-point
        (``eval_local_event_idx``); the sender's events up to and including
        that index become one training/eval sequence.  With
        ``shuffle_within_sequence`` the prefix order is permuted with a
        deterministic per-example RNG (seeded from ``pair_event_id`` and
        ``label``) — an ablation that destroys temporal order.

        Returns a dict of parallel lists/arrays keyed by
        receiver_seqs / numeric_seqs / position_seqs / labels / pair_event_ids.
        NOTE(review): assumes ``examples`` has columns sender_id,
        eval_local_event_idx, label, pair_event_id — confirm against caller.
        """
        if examples.empty:
            return {
                "receiver_seqs": [],
                "numeric_seqs": [],
                "position_seqs": [],
                "labels": np.zeros(0, dtype=np.float32),
                "pair_event_ids": np.zeros(0, dtype=np.int64),
            }
        df = df.sort_values("timestamp").reset_index(drop=True).copy()
        if "local_event_idx" not in df.columns:
            df["local_event_idx"] = df.groupby("sender_id").cumcount().astype(np.int32)
        groups = {
            int(sender_id): group.reset_index(drop=True).copy()
            for sender_id, group in df.groupby("sender_id", sort=False)
        }
        receiver_seqs: list[np.ndarray] = []
        numeric_seqs: list[np.ndarray] = []
        position_seqs: list[np.ndarray] = []
        labels: list[float] = []
        pair_event_ids: list[int] = []
        for row in examples.itertuples(index=False):
            sender_id = int(row.sender_id)
            group = groups.get(sender_id)
            if group is None or group.empty:
                # No history for this sender: single all-padding step.
                receiver_tokens = np.zeros((1,), dtype=np.int64)
                numeric = np.zeros((1, 6), dtype=np.float32)
                positions = np.zeros((1,), dtype=np.int64)
            else:
                end_idx = int(row.eval_local_event_idx)
                # Inclusive prefix up to the evaluation event.
                prefix = group.iloc[: end_idx + 1].copy()
                receiver_ids = prefix["receiver_id"].to_numpy(dtype=np.int64)
                numeric = self._build_event_numeric(prefix)
                perm = None
                if shuffle_within_sequence and len(receiver_ids) > 1:
                    # Deterministic per-example permutation; 97/13 decorrelate
                    # pair id and label contributions to the seed.
                    rng = np.random.default_rng(seed + int(row.pair_event_id) * 97 + int(row.label) * 13)
                    perm = rng.permutation(len(receiver_ids))
                receiver_tokens, numeric, positions = self._finalize_sequence(
                    receiver_ids,
                    numeric,
                    perm=perm,
                )
            receiver_seqs.append(receiver_tokens)
            numeric_seqs.append(numeric)
            position_seqs.append(positions)
            labels.append(float(row.label))
            pair_event_ids.append(int(row.pair_event_id))
        return {
            "receiver_seqs": receiver_seqs,
            "numeric_seqs": numeric_seqs,
            "position_seqs": position_seqs,
            "labels": np.asarray(labels, dtype=np.float32),
            "pair_event_ids": np.asarray(pair_event_ids, dtype=np.int64),
        }

    def _dataset_subset(self, dataset: dict, idx: np.ndarray) -> dict:
        """Select the rows of a matched-example dataset at indices ``idx``."""
        idx_list = idx.tolist()
        return {
            "receiver_seqs": [dataset["receiver_seqs"][i] for i in idx_list],
            "numeric_seqs": [dataset["numeric_seqs"][i] for i in idx_list],
            "position_seqs": [dataset["position_seqs"][i] for i in idx_list],
            "labels": dataset["labels"][idx],
            "pair_event_ids": dataset["pair_event_ids"][idx],
        }

    def _predict_dataset(self, dataset: dict, batch_size: int = 256) -> np.ndarray:
        """Run inference over a matched-example dataset in mini-batches.

        Returns sigmoid probabilities as float32; short-circuits to the
        constant probability if training was degenerate.
        """
        if self._constant_prob is not None:
            return np.full(len(dataset["labels"]), self._constant_prob, dtype=np.float32)
        assert self._model is not None, "Call fit() first."
        if len(dataset["labels"]) == 0:
            return np.zeros(0, dtype=np.float32)
        self._model.eval()
        preds: list[np.ndarray] = []
        with torch.no_grad():
            for start in range(0, len(dataset["labels"]), batch_size):
                end = min(len(dataset["labels"]), start + batch_size)
                receiver_ids, numeric_feats, positions, lengths = self._pad_example_batch(
                    dataset["receiver_seqs"][start:end],
                    dataset["numeric_seqs"][start:end],
                    dataset["position_seqs"][start:end],
                )
                logits = self._model(receiver_ids, numeric_feats, positions, lengths)
                preds.append(torch.sigmoid(logits).cpu().numpy().astype(np.float32))
        return np.concatenate(preds, axis=0)

    def fit_matched_prefix_examples(
        self,
        df_train: pd.DataFrame,
        train_examples: pd.DataFrame,
        seed: int = 0,
        max_epochs: int = 32,
        patience: int = 6,
        valid_frac: float = 0.20,
        pair_batch_size: int = 64,
        learning_rate: float = 2e-3,
        weight_decay: float = 1e-4,
        shuffle_within_sequence: bool = False,
    ) -> dict:
        """Train on matched prefix examples with pair-level early stopping.

        The train/validation split is made at the *pair* level (all examples
        sharing a pair_event_id land on the same side) to avoid twin leakage.
        Batches are also formed by whole pairs.  Early stopping tracks
        validation ROC-AUC with ``patience``; the best state dict is restored
        at the end.  Returns a summary dict of the best epoch/metrics.
        """
        assert self._model is not None, "Call fit() first."
        dataset = self._build_matched_example_dataset(
            df_train,
            train_examples,
            shuffle_within_sequence=shuffle_within_sequence,
            seed=seed,
        )
        y = dataset["labels"]
        if len(y) == 0 or len(np.unique(y)) < 2:
            # Degenerate labels: fall back to a constant-probability model.
            self._constant_prob = float(y.mean()) if len(y) else 0.0
            return {
                "best_epoch": 0,
                "best_valid_roc_auc": float("nan"),
                "best_valid_pr_auc": float("nan"),
                "train_examples": int(len(y)),
                "valid_examples": 0,
            }
        pair_ids = np.unique(dataset["pair_event_ids"])
        rng = np.random.default_rng(seed)
        shuffled_pair_ids = rng.permutation(pair_ids)
        # Hold out a validation split only when there are at least 5 pairs.
        valid_pairs = int(max(1, round(len(shuffled_pair_ids) * valid_frac))) if len(shuffled_pair_ids) >= 5 else 0
        if valid_pairs >= len(shuffled_pair_ids):
            valid_pairs = max(1, len(shuffled_pair_ids) - 1)
        valid_pair_ids = set(shuffled_pair_ids[:valid_pairs].tolist()) if valid_pairs > 0 else set()
        valid_mask = np.isin(dataset["pair_event_ids"], list(valid_pair_ids)) if valid_pair_ids else np.zeros(len(y), dtype=bool)
        train_mask = ~valid_mask
        train_idx = np.flatnonzero(train_mask)
        valid_idx = np.flatnonzero(valid_mask)
        if len(train_idx) == 0:
            # Safety net: never train on an empty set.
            train_idx = np.arange(len(y))
            valid_idx = np.zeros(0, dtype=np.int64)
        train_dataset = self._dataset_subset(dataset, train_idx)
        valid_dataset = self._dataset_subset(dataset, valid_idx) if len(valid_idx) else None
        train_pair_order = np.unique(train_dataset["pair_event_ids"])
        # Index examples by pair so batches always contain whole pairs.
        pair_to_indices: dict[int, list[int]] = {}
        for idx, pair_event_id in enumerate(train_dataset["pair_event_ids"].tolist()):
            pair_to_indices.setdefault(int(pair_event_id), []).append(idx)
        optimizer = torch.optim.AdamW(
            self._model.parameters(),
            lr=learning_rate,
            weight_decay=weight_decay,
        )
        loss_fn = nn.BCEWithLogitsLoss()
        best_state = copy.deepcopy(self._model.state_dict())
        best_epoch = 0
        best_valid_roc = -np.inf
        best_valid_pr = float("nan")
        stale_epochs = 0
        # Enforce a floor of 12 epochs regardless of max_epochs.
        n_epochs = max(12, max_epochs)
        for epoch in range(n_epochs):
            self._model.train()
            epoch_pair_ids = rng.permutation(train_pair_order)
            for start in range(0, len(epoch_pair_ids), pair_batch_size):
                batch_pair_ids = epoch_pair_ids[start : start + pair_batch_size]
                batch_indices: list[int] = []
                for pair_event_id in batch_pair_ids.tolist():
                    batch_indices.extend(pair_to_indices[int(pair_event_id)])
                receiver_ids, numeric_feats, positions, lengths = self._pad_example_batch(
                    [train_dataset["receiver_seqs"][i] for i in batch_indices],
                    [train_dataset["numeric_seqs"][i] for i in batch_indices],
                    [train_dataset["position_seqs"][i] for i in batch_indices],
                )
                labels = torch.tensor(
                    train_dataset["labels"][batch_indices],
                    dtype=torch.float32,
                    device=self.device,
                )
                logits = self._model(receiver_ids, numeric_feats, positions, lengths)
                loss = loss_fn(logits, labels)
                optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self._model.parameters(), 1.0)
                optimizer.step()
            if valid_dataset is None or len(valid_dataset["labels"]) == 0:
                # No validation set: keep the latest weights every epoch.
                best_state = copy.deepcopy(self._model.state_dict())
                best_epoch = epoch + 1
                continue
            valid_probs = self._predict_dataset(valid_dataset)
            valid_roc = _safe_roc_auc(valid_dataset["labels"], valid_probs)
            valid_pr = _safe_pr_auc(valid_dataset["labels"], valid_probs)
            # Require an improvement margin of 1e-4 to reset patience.
            if valid_roc > best_valid_roc + 1e-4:
                best_valid_roc = valid_roc
                best_valid_pr = valid_pr
                best_state = copy.deepcopy(self._model.state_dict())
                best_epoch = epoch + 1
                stale_epochs = 0
            else:
                stale_epochs += 1
                if stale_epochs >= patience:
                    break
        self._model.load_state_dict(best_state)
        self._model.eval()
        self._constant_prob = None
        return {
            "best_epoch": int(best_epoch),
            "best_valid_roc_auc": float(best_valid_roc) if best_valid_roc > -np.inf else float("nan"),
            "best_valid_pr_auc": float(best_valid_pr),
            "train_examples": int(len(train_dataset["labels"])),
            "valid_examples": int(len(valid_dataset["labels"])) if valid_dataset is not None else 0,
        }

    def predict_matched_prefix_examples(
        self,
        df_eval: pd.DataFrame,
        examples: pd.DataFrame,
        seed: int = 0,
        shuffle_within_sequence: bool = False,
        batch_size: int = 256,
    ) -> np.ndarray:
        """Score matched prefix examples; returns per-example probabilities."""
        dataset = self._build_matched_example_dataset(
            df_eval,
            examples,
            shuffle_within_sequence=shuffle_within_sequence,
            seed=seed,
        )
        return self._predict_dataset(dataset, batch_size=batch_size)

    def train_node_classifier_on_prefix(
        self,
        df_prefix: pd.DataFrame,
        eval_nodes: List[int],
        y_labels: np.ndarray,
        num_epochs: int = 150,
    ) -> None:
        """Full-batch training of the network as a node-level classifier.

        Sequences are each node's entire history in ``df_prefix``.  Class
        imbalance is handled via ``pos_weight`` (capped at 10).  Degenerate
        labels trigger the constant-probability fallback instead of training.
        """
        assert self._model is not None, "Call fit() first."
        y = np.asarray(y_labels, dtype=np.float32)
        if len(y) == 0 or len(np.unique(y)) < 2:
            self._constant_prob = float(y.mean()) if len(y) else 0.0
            return
        receiver_ids, numeric_feats, positions, lengths = self._build_sequences(df_prefix, eval_nodes)
        y_t = torch.tensor(y, dtype=torch.float32, device=self.device)
        # neg/pos ratio, capped so extreme imbalance doesn't blow up the loss.
        pos_weight = torch.clamp((y_t == 0).sum() / ((y_t == 1).sum() + 1e-6), max=10.0)
        loss_fn = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
        optimizer = torch.optim.Adam(self._model.parameters(), lr=1e-3)
        # Clamp num_epochs // 2 into [24, 64].
        n_epochs = max(24, min(64, max(1, num_epochs // 2)))
        self._model.train()
        for _ in range(n_epochs):
            logits = self._model(receiver_ids, numeric_feats, positions, lengths)
            loss = loss_fn(logits, y_t)
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self._model.parameters(), 1.0)
            optimizer.step()
        self._constant_prob = None
        self._model.eval()

    def predict(self, df_eval: pd.DataFrame, eval_nodes: List[int]) -> np.ndarray:
        """Score each eval node on its full history; returns probabilities."""
        if self._constant_prob is not None:
            return np.full(len(eval_nodes), self._constant_prob, dtype=np.float32)
        assert self._model is not None, "Call fit() first."
        receiver_ids, numeric_feats, positions, lengths = self._build_sequences(df_eval, eval_nodes)
        self._model.eval()
        with torch.no_grad():
            logits = self._model(receiver_ids, numeric_feats, positions, lengths)
            probs = torch.sigmoid(logits).cpu().numpy()
        return probs.astype(np.float32)

    def reset_memory(self) -> None:
        """No cross-call temporal memory to clear for this model."""
        pass