# (Git history residue pasted above the code; kept as comments so the module parses.)
# asammoud
# Re-add large CSVs using Git LFS
# b265364
import torch
import torch.nn as nn
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
from pipeline.main import Main
from pipeline.test import test
from pipeline.evaluate import get_full_err_scores
# Drawing style for run_gnn's sensor graph.
# Highlighted nodes (the central node and its red neighbours) are drawn large;
# every other node is a tiny dot.
anomaly_node_size, default_node_size = 80, 2
# Node fill colours: central node, highlighted neighbours, everything else.
central_node_color, anomaly_node_color, default_node_color = "yellow", "red", "black"
# Edges touching the central node are drawn in this colour.
anomaly_edge_color = "red"
# RGBA for all other edges: a muted purple at alpha 0.1 so they recede.
default_edge_color = (
    0.35686275,
    0.20392157,
    0.34901961,
    0.1,
)
# Training hyper-parameters handed to pipeline.main.Main(train_config, env_config).
# Key meanings are defined by that pipeline; the values here must match the
# ones the loaded checkpoint was trained with.
train_config = dict(
    batch=16,
    epoch=100,
    slide_win=5,
    dim=64,
    slide_stride=1,
    comment="",
    seed=42,
    out_layer_num=1,
    out_layer_inter_dim=128,
    decay=0,
    val_ratio=0.1,
    topk=15,
)

# Environment settings for Main; 'device' is also read by run_gnn.
env_config = dict(
    save_path="",
    dataset="swat",
    report="best",
    device="cpu",
    load_model_path="",
)
def compute_graph(model: nn.Module, X: torch.Tensor):
    """Average the first GNN layer's attention weights into a dense node-by-node matrix.

    Runs one forward pass ``model(X, None)`` with gradients disabled. It then
    reads the attention coefficients (``att_weight_1``) and edge index
    (``edge_index_1``) that the pass leaves on ``model.gnn_layers[0]``.

    Args:
        model: Network exposing ``gnn_layers[0].att_weight_1`` and
            ``gnn_layers[0].edge_index_1`` after being called.
        X: Input batch, unpacked as ``(n_samples, feature_num, slide_win)``.

    Returns:
        A ``(feature_num, feature_num)`` float64 array. Entry ``[i, j]`` is the
        sum of the attention weights on every edge ``i -> j`` (node ids taken
        modulo ``feature_num``), divided by ``n_samples``.
    """
    n_samples, feature_num, _ = X.shape
    with torch.no_grad():
        model(X, None)
    layer = model.gnn_layers[0]
    coeff_weights = layer.att_weight_1.cpu().detach().numpy()
    edge_index = layer.edge_index_1.cpu().detach().numpy()

    n_edges = len(coeff_weights)
    # Node ids are folded modulo feature_num. Presumably the batched graph
    # numbers sample b's nodes as b * feature_num + k (TODO confirm), so this
    # maps every sample's copy of a sensor onto the same row/column.
    src = edge_index[0, :n_edges] % feature_num
    dst = edge_index[1, :n_edges] % feature_num

    weight_mat = np.zeros((feature_num, feature_num))
    # np.add.at is unbuffered, so repeated (src, dst) pairs all accumulate, in
    # edge order. Each edge must carry exactly one attention value; reshape
    # raises otherwise (e.g. several heads -- assumed 1 here, TODO confirm).
    np.add.at(weight_mat, (src, dst), coeff_weights.reshape(n_edges))
    weight_mat /= n_samples
    return weight_mat
def _select_central_node(central_node_id, all_scores, feature_num):
    """Resolve ``central_node_id`` ("auto" or an int-like id) to a validated node index."""
    if central_node_id == "auto":
        # Node with the highest mean score along axis 1. This assumes all_scores
        # is (n_features, n_timesteps) -- TODO confirm against pipeline.evaluate.
        node = int(all_scores.mean(axis=1).argmax())
    else:
        node = int(central_node_id)
    # Reject out-of-range ids up front. A negative id would otherwise index the
    # colour lists from the end and then fail on the layout lookup.
    if not 0 <= node < feature_num:
        raise ValueError(
            f"central node {node} is out of range for {feature_num} features"
        )
    return node


def _style_graph(G, feature_num, central_node, red_nodes):
    """Build (node_colors, node_sizes, edge_colors) highlighting the central node and red nodes.

    ``edge_colors`` follows ``G.edges()`` order, which is the order
    ``nx.draw`` uses for its default edge list.
    """
    node_colors = [default_node_color] * feature_num
    node_sizes = [default_node_size] * feature_num
    # Undirected edge -> its position in G.edges(), for O(1) lookup per red node.
    edge_slot = {frozenset(edge): k for k, edge in enumerate(G.edges())}
    edge_colors = [default_edge_color] * len(edge_slot)

    node_colors[central_node] = central_node_color
    node_sizes[central_node] = anomaly_node_size
    for node in red_nodes:
        if node == central_node:
            continue
        node_colors[node] = anomaly_node_color
        node_sizes[node] = anomaly_node_size
        slot = edge_slot.get(frozenset((node, central_node)))
        if slot is not None:
            edge_colors[slot] = anomaly_edge_color
    return node_colors, node_sizes, edge_colors


def _annotate_node(ax, pos, node, center, offset_scale, label, color, leader_line):
    """Draw ``label`` for ``node``, pushed ``offset_scale`` away from ``center`` along the radial direction."""
    x, y = pos[node]
    dx, dy = x - center[0], y - center[1]
    # The epsilon avoids a division by zero for a node sitting on the centroid.
    norm = np.sqrt(dx ** 2 + dy ** 2) + 1e-6
    x_label = x + offset_scale * dx / norm
    y_label = y + offset_scale * dy / norm
    if leader_line:
        ax.plot([x, x_label], [y, y_label], 'k--', linewidth=0.8)
    ax.text(x_label, y_label,
            s=label,
            bbox=dict(facecolor=color, alpha=0.5),
            horizontalalignment='center')


def run_gnn(
    central_node_id="auto",
    *,
    checkpoint_path="best_05_22_15_03_20.pt",
    attention_threshold=0.1,
    layout_seed=None,
):
    """Load the trained model, score it, and draw the attention graph around one sensor.

    Args:
        central_node_id: "auto" picks the node with the highest mean score in
            ``all_scores``. Otherwise, any value ``int()`` accepts, used as a
            node index.
        checkpoint_path: State-dict file loaded into ``Main(...).model``.
        attention_threshold: A node is highlighted in red when its attention
            weight to or from the central node exceeds this value.
        layout_seed: Forwarded to ``nx.spring_layout``. None leaves the layout
            random.

    Returns:
        A tuple ``(fig, feature_map_dict, red_nodes, central_node, scores, G)``:
        the matplotlib figure; an index -> label dict built from
        ``main.feature_map``; the indices whose score exceeds the threshold;
        the central node index; the per-node max of attention to/from the
        central node; and the self-loop-free networkx graph.

    Raises:
        ValueError: If the resolved central node is not in ``[0, feature_num)``.
    """
    # Take the device from env_config so it is configured in one place.
    device = env_config.get("device", "cpu")
    main = Main(train_config, env_config, debug=False)
    model = main.model.to(device)
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints
    # (consider weights_only=True where the installed torch supports it).
    checkpoint = torch.load(checkpoint_path, map_location=torch.device(device))
    model.load_state_dict(checkpoint)

    _, train_result = test(model, main.train_dataloader)
    _, test_result = test(model, main.test_dataloader)
    # NOTE(review): if pipeline.evaluate keeps the upstream GDN signature
    # get_full_err_scores(test_result, val_result), these arguments look
    # swapped -- confirm.
    all_scores, _ = get_full_err_scores(train_result, test_result)

    X_train = main.train_dataset.x.float().to(device)
    _, feature_num, _ = X_train.shape
    adj_mat = compute_graph(model, X_train[:100])

    central_node = _select_central_node(central_node_id, all_scores, feature_num)

    # Strongest link in either direction between each node and the central node.
    scores = np.maximum(adj_mat[central_node], adj_mat[:, central_node])
    red_nodes = list(np.where(scores > attention_threshold)[0])

    G = nx.from_numpy_array(adj_mat)
    # Materialize the self-loops before removing them from the graph.
    G.remove_edges_from(list(nx.selfloop_edges(G)))

    node_colors, node_sizes, edge_colors = _style_graph(
        G, feature_num, central_node, red_nodes
    )

    pos = nx.spring_layout(G, seed=layout_seed)
    graph_center = np.mean(np.array(list(pos.values())), axis=0)
    offset_scale = 0.3

    fig, ax = plt.subplots(figsize=(8, 6))
    nx.draw(G, pos,
            edge_color=edge_colors,
            node_color=node_colors,
            node_size=node_sizes,
            ax=ax)

    # The central node gets a label with no leader line.
    _annotate_node(ax, pos, central_node, graph_center, offset_scale,
                   main.feature_map[central_node], central_node_color,
                   leader_line=False)
    # Red nodes get a label plus a dotted leader line.
    for node in red_nodes:
        if node == central_node:
            continue
        _annotate_node(ax, pos, node, graph_center, offset_scale,
                       main.feature_map[node], anomaly_node_color,
                       leader_line=True)
    fig.tight_layout()

    # Index -> label mapping (a dict rather than a list, for the Streamlit caller).
    feature_map_dict = dict(enumerate(main.feature_map))
    return fig, feature_map_dict, red_nodes, central_node, scores, G