import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder


def get_only_features_names(name):
    """Return *name* without its first five characters.

    Applied to the names produced by ``ColumnTransformer.get_feature_names_out()``
    in ``clean_csv_file``, where the transformers are registered as ``'num'``
    and ``'cat'``, so the ``num__`` / ``cat__`` prefix is five characters long.
    """
    return name[5:]


def clean_csv_file(csv_file):
    """Load a CSV file and return an imputed, scaled, one-hot encoded DataFrame.

    Numeric columns (``float64`` / ``int64``) are mean-imputed and
    standardised; ``object`` columns are imputed with the most frequent value
    and one-hot encoded.

    Args:
        csv_file: Anything accepted by ``pandas.read_csv``.

    Returns:
        A ``pandas.DataFrame`` whose columns are the transformed feature names
        with the transformer prefix removed.
    """
    df = pd.read_csv(csv_file)

    num_cols = df.select_dtypes(include=['float64', 'int64']).columns
    cat_cols = df.select_dtypes(include=['object']).columns

    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler()),
    ])
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore')),
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, num_cols),
            ('cat', categorical_transformer, cat_cols),
        ],
        # OneHotEncoder emits sparse output; with the default threshold the
        # stacked result can be a scipy sparse matrix, which cannot be passed
        # to the DataFrame constructor below. 0 forces a dense array.
        sparse_threshold=0,
    )
    pipeline = Pipeline(steps=[('preprocessor', preprocessor)])
    transformed = pipeline.fit_transform(df)

    feature_names = [
        get_only_features_names(name)
        for name in preprocessor.get_feature_names_out()
    ]
    return pd.DataFrame(data=transformed, columns=feature_names)


def build_graph(csv_file, *, threshold, corr_type):
    """Build a graph linking features whose correlation exceeds *threshold*.

    Args:
        csv_file: Path (or buffer) of the CSV file to analyse.
        threshold: Only pairs with a correlation strictly greater than this
            value become edges.
        corr_type: Correlation method forwarded to ``DataFrame.corr``.

    Returns:
        A ``networkx.Graph`` whose nodes carry ``node_identifier`` (the
        feature name) and whose edges carry ``corr_value`` (the correlation
        between the two endpoints).
    """
    features = clean_csv_file(csv_file)

    # Long format: one row per (feature, feature) pair. Undefined
    # correlations (NaN) are treated as 0.
    links = features.corr(method=corr_type).fillna(0).stack().reset_index()
    links.columns = ['var_1', 'var_2', 'corr_value']

    is_strong = links['corr_value'] > threshold
    is_not_self = links['var_1'] != links['var_2']
    links_filtered = links.loc[is_strong & is_not_self]

    # Attach each correlation to its own edge. Pulling values from an
    # unordered set of correlations would pair edges with arbitrary values
    # and run out of values when two pairs share the same correlation.
    G = nx.from_pandas_edgelist(
        links_filtered, 'var_1', 'var_2', edge_attr='corr_value'
    )

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node

    return G


class MplColorHelper:
    """Map scalar values to colours of a named matplotlib colormap."""

    def __init__(self, cmap_name, start_val, stop_val):
        """Create a helper normalising values between *start_val* and *stop_val*.

        Args:
            cmap_name: Name of the matplotlib colormap.
            start_val: Value used as ``vmin`` of the normalisation.
            stop_val: Value used as ``vmax`` of the normalisation.
        """
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgba(self, val):
        """Return the RGBA colour for *val* as 0-255 integer components."""
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the colour for *val* as an ``rgb(r,g,b)`` string (alpha dropped)."""
        r, g, b, _ = self.get_rgba(val)
        return f"rgb({r},{g},{b})"


def display_graph(csv_file, *, threshold, corr_type):
    """Render the feature-correlation graph of *csv_file* as HTML.

    Args:
        csv_file: Path (or buffer) of the CSV file to analyse.
        threshold: Minimum correlation (exclusive) for an edge to be drawn.
        corr_type: Correlation method forwarded to ``DataFrame.corr``.

    Returns:
        The HTML produced by the gravis d3 figure's ``to_html()``.
    """
    G = build_graph(csv_file, threshold=threshold, corr_type=corr_type)

    CM_NAME = "Wistia"
    vals = list(nx.get_edge_attributes(G, 'corr_value').values())

    # min()/max() raise ValueError on an empty sequence, so only colour the
    # edges when at least one pair passed the threshold.
    if vals:
        edge_colors = MplColorHelper(CM_NAME, min(vals), max(vals))
        # get rgb string for each edge
        for _, _, data in G.edges(data=True):
            data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    # NOTE(review): no node attribute named 'betweenness_centrality' is set
    # anywhere in this file, so node sizing presumably falls back to the
    # gravis default - confirm whether a centrality should be computed here.
    disp = gv.d3(
        G,
        # graph specs
        graph_height=500,
        # node specs
        node_size_data_source="betweenness_centrality",
        show_node_label=True,
        node_label_data_source='node_identifier',
        # edge specs
        edge_size_data_source='corr_value',
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,
        # force-directed graph specs
        edge_curvature=0.4,
        use_centering_force=True,
        many_body_force_strength=-300,
    )
    return disp.to_html()