File size: 2,420 Bytes
3e69ac8
 
 
 
 
 
 
 
 
 
 
 
 
3ff951d
3e69ac8
 
 
 
 
3ff951d
3e69ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3ff951d
3e69ac8
3ff951d
3e69ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx


def clean_csv_file(csv_file):
    """Load a CSV file into a pandas DataFrame.

    Despite the name, no cleaning is applied here: the data is returned
    exactly as ``pandas.read_csv`` parses it with its default settings.

    Args:
        csv_file: Anything ``pandas.read_csv`` accepts (path, URL or
            file-like object).

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(csv_file)


def build_graph(csv_file, threshold):
    """Build an undirected graph linking strongly correlated CSV columns.

    The CSV is loaded with ``clean_csv_file``, the pairwise Kendall
    correlation of its columns is computed (NaN replaced by 0), and an
    edge is created for every pair of distinct columns whose correlation
    is strictly greater than ``threshold``.

    Args:
        csv_file: CSV source, forwarded to ``clean_csv_file``.
        threshold: Lower bound (exclusive) on the correlation for a pair
            of columns to be connected.

    Returns:
        A ``networkx.Graph`` in which every node has a
        ``node_identifier`` attribute (the column name) and every edge
        has a ``corr_value`` attribute (the correlation of its two
        end points).
    """
    features = clean_csv_file(csv_file)
    links = features.corr(method='kendall').fillna(0).stack().reset_index()
    links.columns = ['var1', 'var2', 'value']

    # Keep only correlations over the threshold and drop self-correlations.
    keep = (links['value'] > threshold) & (links['var1'] != links['var2'])
    links_filtered = links.loc[keep]

    # Carry the correlation along with each edge. The correlation matrix is
    # symmetric, so the (a, b) and (b, a) rows collapse into one undirected
    # edge holding the same value. Pulling values from an unordered,
    # de-duplicated set instead would pair edges with arbitrary correlations
    # and could run out of values when two pairs share the same correlation.
    G = nx.from_pandas_edgelist(links_filtered, 'var1', 'var2', edge_attr='value')

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node

    # Expose the weight under the attribute name the rest of the file reads,
    # as a plain Python float.
    for _, _, data in G.edges(data=True):
        data['corr_value'] = float(data.pop('value'))

    return G


class MplColorHelper:
    """Translate scalar values into colours of a named matplotlib colormap.

    Values are normalised linearly between ``start_val`` and ``stop_val``
    before being looked up in the colormap.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        """Set up the colormap, the normaliser and the scalar mapper.

        Args:
            cmap_name: Name of a matplotlib colormap.
            start_val: Value mapped to the low end of the colormap.
            stop_val: Value mapped to the high end of the colormap.
        """
        colormap = plt.get_cmap(cmap_name)
        normalizer = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)

        self.cmap_name = cmap_name
        self.cmap = colormap
        self.norm = normalizer
        self.scalarMap = cm.ScalarMappable(norm=normalizer, cmap=colormap)

    def get_rgba(self, val):
        """Return the RGBA colour for ``val`` (requested with ``bytes=True``)."""
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the colour for ``val`` as an ``rgb(r,g,b)`` string."""
        # The alpha channel is unpacked but deliberately left out of the string.
        red, green, blue, _alpha = self.get_rgba(val)
        return f"rgb({red},{green},{blue})"
    

def display_graph(csv_file, threshold):
    """Render the correlation graph of a CSV file as an HTML string.

    The graph comes from ``build_graph``; each edge is coloured along the
    "Wistia" colormap according to its ``corr_value`` and the result is
    drawn with ``gravis.d3``.

    Args:
        csv_file: CSV source, forwarded to ``build_graph``.
        threshold: Correlation threshold, forwarded to ``build_graph``.

    Returns:
        The HTML text produced by the gravis figure's ``to_html()``.
    """
    G = build_graph(csv_file, threshold=threshold)

    CM_NAME = "Wistia"

    corr_values = list(nx.get_edge_attributes(G, 'corr_value').values())

    # A threshold above every correlation leaves the graph without edges;
    # min()/max() would raise ValueError on the empty sequence, so edges are
    # only coloured when there is at least one.
    if corr_values:
        edge_colors = MplColorHelper(CM_NAME, min(corr_values), max(corr_values))

        # get rgb string for each edge
        for _, _, data in G.edges(data=True):
            data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    # NOTE(review): neither this function nor build_graph stores a
    # 'betweenness_centrality' attribute on the nodes, so the node size data
    # source below presumably has nothing to read - confirm whether the
    # centrality should be computed and attached before rendering.
    disp = gv.d3(
            G,
            # graph specs
            graph_height=500,

            # node specs
            node_size_data_source="betweenness_centrality",
            show_node_label=True,
            node_label_data_source='node_identifier',

            # edge specs
            edge_size_data_source='corr_value',
            use_edge_size_normalization=True,
            edge_size_normalization_min=0.3,
            edge_size_normalization_max=6,

            # force-directed graph specs
            many_body_force_strength=-500,
        )

    return disp.to_html()