File size: 3,675 Bytes
3e69ac8
 
 
 
 
 
1c6ad85
 
 
 
3e69ac8
 
1c6ad85
 
 
3e69ac8
1c6ad85
3e69ac8
775aa8b
1c6ad85
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775aa8b
1c6ad85
3e69ac8
 
775aa8b
1f95992
3e69ac8
1f95992
 
 
3e69ac8
1f95992
 
 
3e69ac8
1f95992
3e69ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
775aa8b
3e69ac8
775aa8b
3e69ac8
 
 
 
 
 
 
 
 
 
775aa8b
3e69ac8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a7d5430
9c571c6
a7d5430
3e69ac8
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder


def get_only_features_names(name):
    """Strip the ``<transformer>__`` prefix from a ColumnTransformer output name.

    ``ColumnTransformer.get_feature_names_out`` prefixes every feature with the
    name of the transformer that produced it, e.g. ``"num__age"`` or
    ``"cat__color_red"``. Splitting on the first ``"__"`` works for any
    transformer name instead of assuming a 3-character one; later ``"__"``
    occurrences inside the feature name itself are preserved.

    Names that contain no ``"__"`` separator are returned unchanged.
    """
    _, sep, feature = name.partition("__")
    return feature if sep else name

def clean_csv_file(csv_file):
    """Load a CSV file and return its preprocessed features as a dense DataFrame.

    * ``float64`` / ``int64`` columns: missing values are replaced by the
      column mean, then the column is standardised with ``StandardScaler``.
    * ``object`` columns: missing values are replaced by the most frequent
      value, then the column is one-hot encoded (unknown categories ignored).
    * Columns of any other dtype are dropped (``ColumnTransformer`` default).

    Parameters
    ----------
    csv_file : str, path-like or file-like
        Passed straight to :func:`pandas.read_csv`.

    Returns
    -------
    pandas.DataFrame
        One column per output feature, named after
        ``ColumnTransformer.get_feature_names_out`` with the transformer
        prefix removed by ``get_only_features_names``.
    """
    df = pd.read_csv(csv_file)

    num_cols = df.select_dtypes(include=['float64', 'int64']).columns
    cat_cols = df.select_dtypes(include=['object']).columns

    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler()),
    ])

    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore')),
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, num_cols),
            ('cat', categorical_transformer, cat_cols),
        ],
        # OneHotEncoder emits a sparse matrix, and with the default
        # sparse_threshold (0.3) the stacked result stays sparse whenever its
        # overall density is below that threshold (e.g. a categorical column
        # with several levels). pd.DataFrame cannot build a frame from a scipy
        # sparse matrix, so always request a dense ndarray.
        sparse_threshold=0,
    )

    pipeline = Pipeline(steps=[('preprocessor', preprocessor)])
    transform = pipeline.fit_transform(df)

    feature_names = [
        get_only_features_names(name)
        for name in preprocessor.get_feature_names_out()
    ]
    return pd.DataFrame(data=transform, columns=feature_names)


def build_graph(csv_file, *, threshold, corr_type):
    """Build an undirected graph linking strongly correlated features.

    The CSV is preprocessed with ``clean_csv_file``, a pairwise correlation
    matrix is computed, and an edge is added between every pair of distinct
    features whose correlation is strictly greater than ``threshold``.

    Parameters
    ----------
    csv_file : str, path-like or file-like
        Input accepted by ``clean_csv_file``.
    threshold : float
        Edges are kept only for correlations ``> threshold``.
    corr_type : str or callable
        Forwarded as ``method`` to :meth:`pandas.DataFrame.corr`
        (e.g. ``"pearson"``, ``"spearman"``, ``"kendall"``).

    Returns
    -------
    networkx.Graph
        Nodes carry a ``node_identifier`` attribute equal to the feature name;
        edges carry a ``corr_value`` attribute holding that pair's correlation.
    """
    features = clean_csv_file(csv_file)

    links = features.corr(method=corr_type)
    # NaN correlations (e.g. from zero-variance columns) are treated as 0,
    # i.e. uncorrelated, so they never pass the threshold filter below.
    links = links.fillna(0).stack().reset_index()
    links.columns = ['var_1', 'var_2', 'corr_value']

    # Drop each feature's self-correlation and keep only strong links.
    links_filtered = links.loc[
        (links['corr_value'] > threshold) & (links['var_1'] != links['var_2'])
    ]

    # Attach each row's own correlation to its edge. The correlation matrix is
    # symmetric, so the (a, b) and (b, a) rows carry the same value and collapse
    # onto a single undirected edge.
    G = nx.from_pandas_edgelist(
        links_filtered, 'var_1', 'var_2', edge_attr='corr_value'
    )

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node

    return G


class MplColorHelper:
    """Translate scalar values into colours from a named matplotlib colormap.

    Values are linearly normalised to the ``[start_val, stop_val]`` range
    before the colormap lookup.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        colormap = plt.get_cmap(cmap_name)
        normaliser = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)

        self.cmap_name = cmap_name
        self.cmap = colormap
        self.norm = normaliser
        self.scalarMap = cm.ScalarMappable(norm=normaliser, cmap=colormap)

    def get_rgba(self, val):
        """Return the colour for *val* as RGBA components in bytes form (0-255)."""
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the colour for *val* as a CSS ``rgb(r,g,b)`` string; alpha is dropped."""
        red, green, blue, _alpha = self.get_rgba(val)
        return f"rgb({red},{green},{blue})"
    

def display_graph(csv_file, *, threshold, corr_type):
    """Render the feature-correlation graph of a CSV file as an HTML string.

    The graph comes from ``build_graph``. Each edge is coloured on the
    ``"Wistia"`` colormap, scaled between the smallest and largest edge
    correlation, and its drawn width is taken from ``corr_value`` (normalised
    by gravis to the 0.3-6 range). If no pair passes ``threshold``, the graph
    has no edges and is passed to gravis without colouring.

    Parameters
    ----------
    csv_file : str, path-like or file-like
        Input accepted by ``build_graph``.
    threshold : float
        Minimum (exclusive) correlation for an edge, forwarded to ``build_graph``.
    corr_type : str or callable
        Correlation method, forwarded to ``build_graph``.

    Returns
    -------
    str
        The HTML produced by the gravis figure's ``to_html()``.
    """
    G = build_graph(csv_file, threshold=threshold, corr_type=corr_type)

    CM_NAME = "Wistia"

    corr_values = list(nx.get_edge_attributes(G, 'corr_value').values())
    # min()/max() raise ValueError on an empty sequence, so only build the
    # colour scale when there is at least one edge to colour.
    if corr_values:
        edge_colors = MplColorHelper(CM_NAME, min(corr_values), max(corr_values))

        # get rgb string for each edge
        for _, _, data in G.edges(data=True):
            data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    disp = gv.d3(
            G,
            # graph specs
            graph_height=500,

            # node specs
            # NOTE(review): no 'betweenness_centrality' node attribute is set in
            # the visible code; unless gravis derives it, node sizes likely fall
            # back to the default — confirm.
            node_size_data_source="betweenness_centrality",
            show_node_label=True,
            node_label_data_source='node_identifier',

            # edge specs
            edge_size_data_source='corr_value',
            use_edge_size_normalization=True,
            edge_size_normalization_min=0.3,
            edge_size_normalization_max=6,

            # force-directed graph specs
            edge_curvature=0.4,
            use_centering_force=True,
            many_body_force_strength=-300,
        )

    return disp.to_html()