File size: 4,740 Bytes
b265364 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import torch
import torch.nn as nn
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from pipeline.main import Main
from pipeline.test import test
from pipeline.evaluate import get_full_err_scores
# Plot styling used by run_gnn: the central node and the nodes/edges linked to
# it are drawn large and coloured, every other node/edge stays small and faint.
anomaly_node_size, default_node_size = 80, 2
central_node_color = "yellow"
anomaly_node_color = anomaly_edge_color = "red"
default_node_color = "black"
# RGBA tuple; the 0.1 alpha keeps the non-highlighted edges faint.
default_edge_color = (0.35686275, 0.20392157, 0.34901961, 0.1)
# Training hyper-parameters handed to pipeline.main.Main by run_gnn.
train_config = dict(
    batch=16,
    epoch=100,
    slide_win=5,
    dim=64,
    slide_stride=1,
    comment='',
    seed=42,
    out_layer_num=1,
    out_layer_inter_dim=128,
    decay=0,
    val_ratio=0.1,
    topk=15,
)
# Runtime/environment options handed to pipeline.main.Main by run_gnn.
env_config = dict(
    save_path='',
    dataset='swat',
    report='best',
    device='cpu',
    load_model_path='',
)
def compute_graph(model: nn.Module, X: torch.Tensor) -> np.ndarray:
    """Average the first GNN layer's attention weights into a dense matrix.

    Runs one forward pass ``model(X, None)`` and then reads the attention
    coefficients (``att_weight_1``) and edge list (``edge_index_1``) exposed
    by ``model.gnn_layers[0]``. Node ids in the edge list are folded back onto
    feature ids with ``% feature_num``; coefficients of every edge that lands
    on the same ``(i, j)`` pair are summed and the result is divided by the
    number of samples in ``X``.

    Args:
        model: network whose ``gnn_layers[0]`` sets ``att_weight_1`` and
            ``edge_index_1`` during the forward pass.
        X: input batch, unpacked as ``(n_samples, feature_num, slide_win)``.

    Returns:
        A ``(feature_num, feature_num)`` float array. Entry ``[i, j]`` is the
        summed coefficient of edges with ``edge_index[0] % feature_num == i``
        and ``edge_index[1] % feature_num == j``, divided by ``n_samples``.

    Raises:
        ValueError: if ``X`` contains no samples, or the number of attention
            coefficients does not equal the number of edges.
    """
    n_samples, feature_num, _ = X.shape
    if n_samples == 0:
        # Dividing by zero below would otherwise yield an all-NaN matrix.
        raise ValueError("compute_graph needs at least one sample in X")

    # Evaluate in eval mode so training-only behaviour (e.g. dropout) cannot
    # perturb the attention weights, then restore the caller's mode.
    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            model(X, None)
    finally:
        model.train(was_training)

    layer = model.gnn_layers[0]
    edge_index = layer.edge_index_1.cpu().detach().numpy()
    # NOTE(review): assumes exactly one coefficient per edge (single attention
    # head); trailing singleton dims such as (E, 1, 1) are flattened away.
    coeff_weights = layer.att_weight_1.cpu().detach().numpy().reshape(-1)
    num_edges = edge_index.shape[1]
    if coeff_weights.shape[0] != num_edges:
        raise ValueError(
            f"expected one attention coefficient per edge: got "
            f"{coeff_weights.shape[0]} coefficients for {num_edges} edges"
        )

    src = edge_index[0] % feature_num
    dst = edge_index[1] % feature_num
    weight_mat = np.zeros((feature_num, feature_num))
    # np.add.at is unbuffered, so repeated (src, dst) pairs all accumulate
    # instead of only the last write surviving.
    np.add.at(weight_mat, (src, dst), coeff_weights)
    weight_mat /= n_samples
    return weight_mat
def _annotate_node(ax, xy, center, text, facecolor, offset_scale, leader_line):
    """Draw a boxed label for one node, pushed outward from ``center``."""
    x, y = xy
    dx, dy = x - center[0], y - center[1]
    # The 1e-6 avoids dividing by zero for a node sitting on the center.
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6
    x_offset = x + offset_scale * dx / norm
    y_offset = y + offset_scale * dy / norm
    if leader_line:
        ax.plot([x, x_offset], [y, y_offset], 'k--', linewidth=0.8)
    ax.text(x_offset, y_offset,
            s=text,
            bbox=dict(facecolor=facecolor, alpha=0.5),
            horizontalalignment='center')


def run_gnn(central_node_id="auto", *, checkpoint_path="best_05_22_15_03_20.pt",
            threshold=0.1, max_graph_samples=100, layout_seed=None):
    """Load the trained model and plot the attention graph around one node.

    Builds ``Main(train_config, env_config)``, loads weights from
    ``checkpoint_path``, scores the train and test loaders, and averages the
    first GNN layer's attention over the first ``max_graph_samples`` training
    samples via ``compute_graph``. Every node whose attention to or from the
    central node exceeds ``threshold`` is highlighted and labelled.

    Args:
        central_node_id: feature index (anything ``int()`` accepts) to centre
            the plot on, or ``"auto"`` to pick the row of the error scores
            with the highest mean.
        checkpoint_path: state-dict file passed to ``torch.load``.
        threshold: minimum attention weight for a node to be highlighted.
        max_graph_samples: number of leading training samples passed to
            ``compute_graph``.
        layout_seed: seed for ``nx.spring_layout``; with ``None`` the layout
            may differ between calls.

    Returns:
        ``(fig, feature_map_dict, red_nodes, central_node, scores, G)``: the
        matplotlib figure; ``{index: main.feature_map[index]}``; the indices
        with ``scores > threshold`` (this may include the central node); the
        central node index; per-node ``max(adj[c, j], adj[j, c])``; and the
        undirected graph built from the attention matrix without self-loops.

    Raises:
        ValueError: if ``central_node_id`` is not ``"auto"`` and is not a
            valid feature index.
    """
    # Single source of truth for the device: the same config Main receives.
    device = env_config.get("device", "cpu")
    main = Main(train_config, env_config, debug=False)
    model = main.model.to(device)  # nn.Module.to returns the same module
    # NOTE: torch.load unpickles the file -- only load checkpoints you trust.
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint)
    model.eval()  # inference only from here on

    _, train_result = test(model, main.train_dataloader)
    _, test_result = test(model, main.test_dataloader)
    all_scores, _ = get_full_err_scores(train_result, test_result)

    X_train = main.train_dataset.x.float().to(device)
    _, feature_num, _ = X_train.shape
    adj_mat = compute_graph(model, X_train[:max_graph_samples])

    if central_node_id == "auto":
        # Presumably all_scores rows are per-feature error series -- confirm.
        central_node = all_scores.mean(axis=1).argmax()
    else:
        central_node = int(central_node_id)
        if not 0 <= central_node < feature_num:
            raise ValueError(
                f"central_node_id must be in [0, {feature_num}), got {central_node}"
            )

    # Strongest attention in either direction between each node and the centre.
    scores = np.maximum(adj_mat[central_node], adj_mat[:, central_node])
    red_nodes = list(np.where(scores > threshold)[0])

    G = nx.from_numpy_array(adj_mat)
    # Materialise the self-loop list before removing to avoid mutating G
    # while its edges are being iterated.
    G.remove_edges_from(list(nx.selfloop_edges(G)))
    # nx.draw colours edges in G.edges() order; map each undirected edge to
    # its position so recolouring is an O(1) lookup instead of a list scan.
    edge_position = {frozenset(edge): k for k, edge in enumerate(G.edges())}
    edge_colors = [default_edge_color] * len(edge_position)
    node_colors = [default_node_color] * feature_num
    node_sizes = [default_node_size] * feature_num

    node_colors[central_node] = central_node_color
    node_sizes[central_node] = anomaly_node_size
    for node in red_nodes:
        if node == central_node:
            continue
        node_colors[node] = anomaly_node_color
        node_sizes[node] = anomaly_node_size
        position = edge_position.get(frozenset((node, central_node)))
        if position is not None:
            edge_colors[position] = anomaly_edge_color

    pos = nx.spring_layout(G, seed=layout_seed)
    graph_center = np.mean(np.array(list(pos.values())), axis=0)
    offset_scale = 0.3
    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw(G, pos,
            edge_color=edge_colors,
            node_color=node_colors,
            node_size=node_sizes,
            ax=ax)

    # Central node label (no leader line), then red node labels with one.
    _annotate_node(ax, pos[central_node], graph_center,
                   main.feature_map[central_node], central_node_color,
                   offset_scale, leader_line=False)
    for node in red_nodes:
        if node == central_node:
            continue
        _annotate_node(ax, pos[node], graph_center,
                       main.feature_map[node], anomaly_node_color,
                       offset_scale, leader_line=True)
    fig.tight_layout()

    # Feature map as {index: label} for Streamlit compatibility.
    feature_map_dict = dict(enumerate(main.feature_map))
    return fig, feature_map_dict, red_nodes, central_node, scores, G
|