from collections import defaultdict

import torch
from torch_geometric.data import Data
|
|
class TransactionGraphBuilder:
    """Incrementally build a PyG ``Data`` graph from transaction records.

    Each transaction contributes up to three nodes (account, merchant, device)
    and two directed edges: account -> merchant and account -> device. Nodes
    are deduplicated per ``(node_type, node_key)``, so an account seen in
    several transactions maps to a single node, while an account and a
    merchant that share an identifier stay distinct.

    Attributes:
        node_index: Maps ``(node_type, node_key)`` to the integer node id.
        current_id: Next node id to hand out (equals the number of nodes).
        edges: ``(source_id, target_id)`` pairs in insertion order; repeated
            transactions append repeated edges.
        node_features: One-hot node-type vector per node, indexed by node id.
        node_types: Integer node type per node, indexed by node id.
    """

    NODE_TYPE_ACCOUNT = 0
    NODE_TYPE_MERCHANT = 1
    NODE_TYPE_DEVICE = 2
    NUM_NODE_TYPES = 3

    def __init__(self):
        # Plain dict rather than defaultdict: an accidental ``node_index[k]``
        # read on a missing key must not silently create an entry aliasing node 0.
        self.node_index = {}
        self.current_id = 0
        self.edges = []
        self.node_features = []
        self.node_types = []

    def get_node_id(self, node_key, node_type):
        """Return the id for ``node_key`` of ``node_type``, creating the node if new.

        Args:
            node_key: Hashable identifier taken from a transaction record.
            node_type: Integer in ``[0, NUM_NODE_TYPES)``.

        Returns:
            The integer node id; ids are assigned densely starting at 0.

        Raises:
            ValueError: If ``node_type`` is outside ``[0, NUM_NODE_TYPES)``.
        """
        if not 0 <= node_type < self.NUM_NODE_TYPES:
            raise ValueError(
                f"node_type must be in [0, {self.NUM_NODE_TYPES}), got {node_type!r}"
            )
        # Namespace the key by type so equal identifiers of different entity
        # kinds (e.g. account 42 and merchant 42) do not collapse into one node.
        key = (node_type, node_key)
        node_id = self.node_index.get(key)
        if node_id is None:
            node_id = self.current_id
            self.node_index[key] = node_id
            self.current_id += 1
            # Append features/types only for new nodes so that row i of
            # node_features always describes node id i.
            self.node_features.append(
                [1.0 if i == node_type else 0.0 for i in range(self.NUM_NODE_TYPES)]
            )
            self.node_types.append(node_type)
        return node_id

    def add_transaction(self, transaction, *, build=True):
        """Add one transaction's nodes and edges.

        Args:
            transaction: Mapping with ``'AccountID'``, ``'MerchantID'`` and
                ``'DeviceID'`` entries.
            build: When true (the default), return the full graph so far.
                Pass ``False`` when bulk-loading to skip the per-call tensor
                rebuild, then call :meth:`build_graph` once at the end.

        Returns:
            ``Data`` covering every transaction added so far if ``build`` is
            true, otherwise ``None``.

        Raises:
            KeyError: If a required key is missing; no nodes or edges are
                added in that case.
        """
        # Read every field before mutating so a missing key cannot leave
        # half-registered nodes behind.
        account_key = transaction['AccountID']
        merchant_key = transaction['MerchantID']
        device_key = transaction['DeviceID']

        acc_id = self.get_node_id(account_key, self.NODE_TYPE_ACCOUNT)
        merchant_id = self.get_node_id(merchant_key, self.NODE_TYPE_MERCHANT)
        device_id = self.get_node_id(device_key, self.NODE_TYPE_DEVICE)

        # NOTE(review): edges are one-directional (account -> other); confirm
        # whether the downstream model expects an undirected graph.
        self.edges.append((acc_id, merchant_id))
        self.edges.append((acc_id, device_id))

        if not build:
            return None
        # Rebuilds tensors from all accumulated nodes/edges: O(total) per call.
        return self.build_graph()

    def build_graph(self):
        """Return a ``Data`` graph with ``x`` and ``edge_index`` for all data so far."""
        return Data(x=self._feature_matrix(), edge_index=self._edge_index())

    def _feature_matrix(self):
        """Node features as a float tensor of shape (num_nodes, NUM_NODE_TYPES)."""
        # reshape keeps the tensor 2-D even before any node has been added.
        return torch.tensor(self.node_features, dtype=torch.float).reshape(
            -1, self.NUM_NODE_TYPES
        )

    def _edge_index(self):
        """Edges as a long tensor of shape (2, num_edges)."""
        # Same layout as transposing the pair list, but also yields (2, 0)
        # instead of a 1-D (0,) tensor when there are no edges yet.
        return (
            torch.tensor(self.edges, dtype=torch.long)
            .reshape(-1, 2)
            .t()
            .contiguous()
        )