# Spaces: Running
import pandas as pd | |
import matplotlib.pyplot as plt | |
import matplotlib as mpl | |
from matplotlib import cm | |
import gravis as gv | |
import networkx as nx | |
def clean_csv_file(csv_file):
    """Load a CSV file and return it without incomplete or repeated rows.

    Args:
        csv_file: Anything ``pandas.read_csv`` accepts as its first argument.

    Returns:
        The parsed DataFrame with every row containing a missing value
        removed, followed by removal of duplicated rows. The surviving rows
        keep their original index labels.
    """
    raw = pd.read_csv(csv_file)
    complete_rows = raw.dropna()
    return complete_rows.drop_duplicates()
def build_graph(csv_file, *, threshold, corr_type):
    """Build an undirected graph linking strongly correlated CSV columns.

    The CSV is cleaned with ``clean_csv_file``, the pairwise correlation of
    its columns is computed, and every pair of distinct columns whose
    correlation is strictly greater than ``threshold`` becomes an edge.

    Args:
        csv_file: Passed through to ``clean_csv_file``.
        threshold: Lower bound (exclusive) a correlation must exceed for the
            pair to be connected.
        corr_type: Forwarded as ``method`` to ``DataFrame.corr``.

    Returns:
        A ``networkx.Graph`` in which each node carries ``node_identifier``
        (the column name) and each edge carries ``corr_value`` (the
        correlation of the two columns it joins).
    """
    features = clean_csv_file(csv_file)
    links = features.corr(method=corr_type)
    # Undefined correlations (NaN) are treated as 0 before flattening the
    # matrix into one (var_1, var_2, corr_val) row per ordered column pair.
    links = links.fillna(0).stack().reset_index()
    links.columns = ['var_1', 'var_2', 'corr_val']
    links_filtered = links.loc[
        (links['corr_val'] > threshold) & (links['var_1'] != links['var_2'])
    ]

    # Attach each correlation to the edge of the row it came from. Pulling
    # values from an unordered, de-duplicated set would pair edges with
    # unrelated values and could run out of values before the edges do.
    G = nx.Graph()
    for var_1, var_2, corr_val in links_filtered.itertuples(index=False, name=None):
        # The flattened matrix lists both (a, b) and (b, a); the second
        # add_edge call just rewrites the same undirected edge.
        G.add_edge(var_1, var_2, corr_value=float(corr_val))

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node
    return G
class MplColorHelper:
    """Map scalar values onto colours of a named matplotlib colormap."""

    def __init__(self, cmap_name, start_val, stop_val):
        """Prepare a scalar-to-colour mapping.

        Args:
            cmap_name: Name handed to ``plt.get_cmap``.
            start_val: Value used as ``vmin`` of the normalisation.
            stop_val: Value used as ``vmax`` of the normalisation.
        """
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgba(self, val):
        """Return the mapped colour of ``val`` via ``to_rgba(..., bytes=True)``."""
        mapper = self.scalarMap
        return mapper.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the colour of ``val`` as an ``rgb(r,g,b)`` string."""
        # The fourth (alpha) component is unpacked but not used in the string.
        red, green, blue, _alpha = self.get_rgba(val)
        return f"rgb({red},{green},{blue})"
def display_graph(csv_file, *, threshold, corr_type):
    """Render the correlation graph of a CSV file with gravis.

    The graph comes from ``build_graph``; every edge is coloured on the
    "Wistia" colormap according to its ``corr_value``, scaled between the
    smallest and largest correlation present in the graph.

    Args:
        csv_file: Passed through to ``build_graph``.
        threshold: Passed through to ``build_graph``.
        corr_type: Passed through to ``build_graph``.

    Returns:
        The result of ``to_html()`` on the figure created by ``gv.d3``.

    Raises:
        ValueError: If the graph has no edges, i.e. no pair of columns has a
            correlation above ``threshold``.
    """
    G = build_graph(csv_file, threshold=threshold, corr_type=corr_type)

    CM_NAME = "Wistia"
    vals = list(nx.get_edge_attributes(G, 'corr_value').values())
    if not vals:
        # min()/max() on an empty sequence would raise an opaque ValueError;
        # keep the exception type but say what actually went wrong.
        raise ValueError(
            f"no pair of columns has a {corr_type} correlation above "
            f"{threshold}; nothing to display"
        )
    val_min, val_max = min(vals), max(vals)
    edge_colors = MplColorHelper(CM_NAME, val_min, val_max)

    # get rgb string for each edge
    for _, _, data in G.edges(data=True):
        data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    disp = gv.d3(
        G,

        # graph specs
        graph_height=500,

        # node specs
        # NOTE(review): build_graph only stores 'node_identifier' on nodes;
        # no 'betweenness_centrality' attribute is set there - confirm this
        # data source is populated somewhere before relying on it.
        node_size_data_source="betweenness_centrality",
        show_node_label=True,
        node_label_data_source='node_identifier',

        # edge specs
        edge_size_data_source='corr_value',
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,

        # force-directed graph specs
        edge_curvature=0.4,
        use_centering_force=True,
        many_body_force_strength=-300,
    )
    return disp.to_html()