import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx


def clean_csv_file(csv_file):
    """Load *csv_file* into a DataFrame.

    No cleaning is performed yet: the frame is returned exactly as
    ``pandas.read_csv`` parses it.
    """
    df = pd.read_csv(csv_file)
    return df


def build_graph(csv_file, threshold):
    """Build a correlation graph from the columns of *csv_file*.

    Kendall rank correlation is computed between every pair of numeric
    columns (NaN correlations are treated as 0). Each pair of distinct
    columns whose correlation is strictly greater than *threshold* becomes
    an undirected edge.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        threshold: Exclusive lower bound on the correlation for an edge.

    Returns:
        A ``networkx.Graph``. Every node carries ``node_identifier`` (the
        column name) and every edge carries ``corr_value`` (the Kendall
        correlation of exactly that column pair).
    """
    features = clean_csv_file(csv_file)
    # Correlation is only meaningful for numeric data; recent pandas raises
    # on text columns instead of silently dropping them as it used to.
    numeric = features.select_dtypes(include=["number", "bool"])
    links = numeric.corr(method="kendall").fillna(0).stack().reset_index()
    links.columns = ["var1", "var2", "corr_value"]

    # Keep only correlations over the threshold, excluding the diagonal.
    links_filtered = links.loc[
        (links["corr_value"] > threshold) & (links["var1"] != links["var2"])
    ]

    # Attach each row's correlation to its own edge. The correlation matrix is
    # symmetric, so the (a, b) and (b, a) rows collapse into a single
    # undirected edge carrying the same value.
    G = nx.from_pandas_edgelist(
        links_filtered, "var1", "var2", edge_attr="corr_value"
    )
    for node, data in G.nodes(data=True):
        data["node_identifier"] = node
    return G


class MplColorHelper:
    """Map scalar values onto colors of a matplotlib colormap.

    Values are linearly normalized between *start_val* and *stop_val*
    before being looked up in the colormap named *cmap_name*.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgba(self, val):
        """Return the color for *val* as an (r, g, b, a) tuple of 0-255 ints."""
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the color for *val* as a CSS ``rgb(r,g,b)`` string (alpha dropped)."""
        r, g, b, a = self.get_rgba(val)
        return f"rgb({r},{g},{b})"


def display_graph(csv_file, threshold):
    """Render the correlation graph of *csv_file* with gravis' d3 backend.

    Edges are colored with the "Wistia" colormap and sized according to
    their correlation value.

    Args:
        csv_file: Path or buffer accepted by ``pandas.read_csv``.
        threshold: Exclusive lower bound on the correlation for an edge.

    Returns:
        The result of the gravis figure's ``to_html()``.
    """
    G = build_graph(csv_file, threshold=threshold)
    CM_NAME = "Wistia"

    vals = list(nx.get_edge_attributes(G, "corr_value").values())
    # With no edge above the threshold there is nothing to color, and
    # min()/max() of an empty sequence would raise ValueError.
    if vals:
        edge_colors = MplColorHelper(CM_NAME, min(vals), max(vals))
        # get rgb string for each edge
        for _, _, data in G.edges(data=True):
            data["color"] = edge_colors.get_rgb_str(data["corr_value"])

    disp = gv.d3(
        G,
        # graph specs
        graph_height=500,
        # node specs
        # NOTE(review): nothing in this module stores a
        # 'betweenness_centrality' node attribute; unless gravis computes it
        # itself, this setting has no effect — confirm.
        node_size_data_source="betweenness_centrality",
        show_node_label=True,
        node_label_data_source="node_identifier",
        # edge specs
        edge_size_data_source="corr_value",
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,
        # force-directed graph specs
        many_body_force_strength=-500,
    )
    return disp.to_html()