File size: 2,293 Bytes
84db192
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import torch.nn as nn
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
import os
# Run from the directory that contains this script so that relative paths
# used later in the file are resolved next to the script itself.
script_path = os.path.abspath(__file__)
script_dir, _ = os.path.split(script_path)
os.chdir(script_dir)
class SimpleNN(nn.Module):
    """Fully connected network with two hidden layers and a single output.

    The data flow is ``fc1 -> ReLU -> dropout1 -> fc2 -> ReLU -> dropout2 ->
    fc3``. The attribute names (``fc1``, ``fc2``, ``fc3``) determine the
    state-dict keys, so they must stay stable for saved checkpoints to load.

    Args:
        input_dim: Number of input features expected by ``fc1``.
        hidden_dim: Width of both hidden layers. Defaults to 100.
        dropout: Drop probability used by both dropout layers. Defaults to 0.5.
    """

    def __init__(self, input_dim, hidden_dim=100, dropout=0.5):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.dropout1 = nn.Dropout(dropout)
        self.fc2 = nn.Linear(hidden_dim, hidden_dim)
        self.dropout2 = nn.Dropout(dropout)
        self.fc3 = nn.Linear(hidden_dim, 1)

    def forward(self, x):
        """Return the network output for ``x``.

        Args:
            x: Tensor whose last dimension has ``input_dim`` features.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size 1. No activation is applied to the output.
        """
        x = torch.relu(self.fc1(x))
        x = self.dropout1(x)
        x = torch.relu(self.fc2(x))
        x = self.dropout2(x)
        x = self.fc3(x)
        return x


# Number of input features; must match the architecture stored in the
# checkpoint loaded below.
input_dim = 51
model = SimpleNN(input_dim)
# map_location='cpu' lets a checkpoint that was saved from a GPU run be loaded
# on a machine without CUDA; the weights are converted to NumPy (CPU) anyway.
# NOTE(review): torch.load unpickles the file, so only load checkpoints from a
# trusted source (or pass weights_only=True on PyTorch versions that have it).
model.load_state_dict(torch.load('best_model.pth', map_location='cpu'))
model.eval()


# One weight matrix per linear layer, in forward order. nn.Linear stores its
# weight as (out_features, in_features).
weights = [
    layer.weight.detach().cpu().numpy()
    for layer in (model.fc1, model.fc2, model.fc3)
]


# Node count per layer, derived from the weight shapes so it cannot drift from
# the loaded model: the input width, then the output width of every layer.
layers = [weights[0].shape[1]] + [w.shape[0] for w in weights]


def draw_neural_network(layers, weights):
    """Draw a fully connected layered network with edges coloured by weight.

    Every layer is drawn as a vertical column of nodes (x == layer index) and
    every node of one layer is connected to every node of the next layer. The
    figure is shown with ``plt.show()``; nothing is returned.

    Args:
        layers: Sequence of node counts, one entry per layer, first layer
            first.
        weights: Sequence of 2-D arrays, one per pair of consecutive layers.
            ``weights[i][k, j]`` is used as the colour value of the edge from
            node ``j`` of layer ``i`` to node ``k`` of layer ``i + 1``.
    """
    graph = nx.Graph()

    # Nodes are named "L<layer>_N<index>"; each column is roughly centred
    # vertically around y == 0.
    pos = {}
    layer_nodes = []
    for layer_idx, num_nodes in enumerate(layers):
        names = [f'L{layer_idx}_N{unit}' for unit in range(num_nodes)]
        layer_nodes.append(names)
        for unit, name in enumerate(names):
            pos[name] = (layer_idx, num_nodes // 2 - unit)
    # Add nodes explicitly so they are drawn even when there are no edges
    # (for example a single-layer input).
    graph.add_nodes_from(pos)

    # Build the edge list and the matching colour values in lock-step.
    edges = []
    edge_colors = []
    for layer_idx, (sources, targets) in enumerate(zip(layer_nodes, layer_nodes[1:])):
        for j, src in enumerate(sources):
            for k, dst in enumerate(targets):
                edges.append((src, dst))
                edge_colors.append(weights[layer_idx][k, j])
    graph.add_edges_from(edges)

    draw_kwargs = {
        'with_labels': False,
        'node_size': 700,
        'node_color': 'lightblue',
    }
    if edges:
        draw_kwargs.update(
            # An undirected Graph does not iterate its edges in insertion
            # order, so the edge list must be passed explicitly; otherwise the
            # colour values would be paired with the wrong edges.
            edgelist=edges,
            edge_color=edge_colors,
            edge_cmap=plt.cm.viridis,
            width=2,
            edge_vmin=min(edge_colors),
            edge_vmax=max(edge_colors),
        )

    plt.figure(figsize=(10, 10))
    nx.draw(graph, pos, **draw_kwargs)

    # Write each node's name slightly above its position.
    for name, (x, y) in pos.items():
        plt.text(x, y + 0.1, name, ha='center', va='center')

    plt.title("Neural Network Visualization")
    plt.show()

# Render the network described by the module-level `layers` and `weights`.
draw_neural_network(layers=layers, weights=weights)