correlationGraph / utils.py
thov's picture
add threshold on correlation values
3ff951d
raw
history blame
No virus
2.42 kB
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx
def clean_csv_file(csv_file):
    """Load *csv_file* into a pandas DataFrame.

    Despite the name, no cleaning is applied: the frame is returned exactly
    as ``pandas.read_csv`` parsed it with its default options.

    Args:
        csv_file: Anything ``pandas.read_csv`` accepts (path, URL or
            file-like object).

    Returns:
        The parsed ``pandas.DataFrame``.
    """
    return pd.read_csv(csv_file)
def build_graph(csv_file, threshold):
    """Build an undirected graph linking strongly correlated CSV columns.

    The CSV is loaded with ``clean_csv_file``, the pairwise Kendall
    correlation of its columns is computed, and an edge is created for every
    pair of distinct columns whose correlation is strictly greater than
    ``threshold``.

    Args:
        csv_file: CSV source forwarded to ``clean_csv_file``.
        threshold: Lower bound (exclusive) on the correlation value an edge
            must have to be kept. The comparison is on the signed value, so
            negative correlations are only kept if ``threshold`` is negative.

    Returns:
        A ``networkx.Graph`` where each node carries ``node_identifier``
        (the column name) and each edge carries ``corr_value`` (the Kendall
        correlation of the two columns it connects).
    """
    features = clean_csv_file(csv_file)

    # Long-format table of every (var1, var2, correlation) combination;
    # undefined correlations (NaN) are treated as 0.
    links = features.corr(method='kendall').fillna(0).stack().reset_index()
    links.columns = ['var1', 'var2', 'value']

    # Keep only correlation over a threshold, and drop self-correlations.
    links_filtered = links.loc[
        (links['value'] > threshold) & (links['var1'] != links['var2'])
    ]

    # Carry the correlation through as an edge attribute so every edge gets
    # the value of its own (var1, var2) pair. Pulling values from an
    # unordered, de-duplicated set would mismatch values and edges, and would
    # run out of values whenever two pairs share the same correlation.
    G = nx.from_pandas_edgelist(links_filtered, 'var1', 'var2', edge_attr='value')

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node

    for _, _, data in G.edges(data=True):
        # Expose the value under the attribute name the rest of the module
        # reads, as a plain Python float.
        data['corr_value'] = float(data.pop('value'))

    return G
class MplColorHelper:
    """Map scalar values to colours of a named matplotlib colormap.

    Values are normalised linearly between ``start_val`` and ``stop_val``
    before being looked up in the colormap.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        """Set up the colormap, the normaliser and the scalar mappable.

        Args:
            cmap_name: Name of the matplotlib colormap to use.
            start_val: Value mapped to the low end of the colormap.
            stop_val: Value mapped to the high end of the colormap.
        """
        colormap = plt.get_cmap(cmap_name)
        normalizer = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)

        self.cmap_name = cmap_name
        self.cmap = colormap
        self.norm = normalizer
        self.scalarMap = cm.ScalarMappable(norm=normalizer, cmap=colormap)

    def get_rgba(self, val):
        """Return the colour for *val* as given by ``to_rgba(val, bytes=True)``."""
        rgba = self.scalarMap.to_rgba(val, bytes=True)
        return rgba

    def get_rgb_str(self, val):
        """Return the colour for *val* as an ``"rgb(r,g,b)"`` string.

        The alpha channel is discarded.
        """
        red, green, blue, _alpha = self.get_rgba(val)
        return f"rgb({red},{green},{blue})"
def display_graph(csv_file, threshold):
    """Render the correlation graph of a CSV file with gravis.

    The graph is built with ``build_graph``; each edge is then coloured from
    the "Wistia" colormap according to its ``corr_value`` (scaled between the
    smallest and largest value present) and the result is drawn with
    ``gv.d3``.

    Args:
        csv_file: CSV source forwarded to ``build_graph``.
        threshold: Correlation threshold forwarded to ``build_graph``.

    Returns:
        Whatever ``to_html()`` of the gravis figure returns (the HTML of the
        rendered graph).
    """
    G = build_graph(csv_file, threshold=threshold)

    CM_NAME = "Wistia"
    corr_values = list(nx.get_edge_attributes(G, 'corr_value').values())

    # A high threshold can leave the graph without any edge; min()/max() on
    # an empty sequence would raise ValueError, so only colour when there is
    # something to colour.
    if corr_values:
        edge_colors = MplColorHelper(CM_NAME, min(corr_values), max(corr_values))

        # get rgb string for each edge
        for _, _, data in G.edges(data=True):
            data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    # NOTE(review): no code visible in this file sets a
    # 'betweenness_centrality' node attribute, so the node size data source
    # below presumably finds nothing to read - confirm how gravis handles
    # that, or compute the attribute before drawing.
    disp = gv.d3(
        G,

        # graph specs
        graph_height=500,

        # node specs
        node_size_data_source="betweenness_centrality",
        show_node_label=True,
        node_label_data_source='node_identifier',

        # edge specs
        edge_size_data_source='corr_value',
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,

        # force-directed graph specs
        many_body_force_strength=-500,
    )

    return disp.to_html()