Spaces:
Sleeping
Sleeping
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import matplotlib as mpl | |
| from matplotlib import cm | |
| import gravis as gv | |
| import networkx as nx | |
| from sklearn.compose import ColumnTransformer | |
| from sklearn.pipeline import Pipeline | |
| from sklearn.impute import SimpleImputer | |
| from sklearn.preprocessing import StandardScaler, OneHotEncoder | |
def get_only_features_names(name):
    """Strip the ``<transformer>__`` prefix from a ColumnTransformer feature name.

    ``ColumnTransformer.get_feature_names_out()`` yields names such as
    ``"num__age"`` or ``"cat__city_Paris"``. Splitting on the first ``"__"``
    (instead of slicing a fixed five characters) keeps the result correct for
    transformer names of any length, not only the three-letter ``num``/``cat``.

    Args:
        name: A feature name, normally of the form ``"<transformer>__<feature>"``.

    Returns:
        The text after the first ``"__"``; ``name`` unchanged when it contains
        no ``"__"`` separator.
    """
    _prefix, separator, feature = name.partition("__")
    return feature if separator else name
def clean_csv_file(csv_file):
    """Load a CSV file and return a fully numeric, preprocessed DataFrame.

    Columns of dtype ``float64``/``int64`` are mean-imputed and standardised.
    Columns of dtype ``object`` are imputed with their most frequent value and
    one-hot encoded. Columns of any other dtype are not listed in either
    transformer and are therefore not part of the output.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A DataFrame holding the transformed features, with column names taken
        from the fitted ColumnTransformer and passed through
        ``get_only_features_names``.
    """
    df = pd.read_csv(csv_file)

    num_cols = df.select_dtypes(include=['float64', 'int64']).columns
    cat_cols = df.select_dtypes(include=['object']).columns

    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler()),
    ])
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore')),
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, num_cols),
            ('cat', categorical_transformer, cat_cols),
        ],
        # OneHotEncoder emits sparse output; with the default threshold the
        # stacked result can come back as a SciPy sparse matrix, which cannot
        # be wrapped in a DataFrame below. 0 forces a dense ndarray.
        sparse_threshold=0,
    )

    # The preprocessor is the only step, so it is fitted directly rather than
    # through a one-element Pipeline.
    transformed = preprocessor.fit_transform(df)
    column_names = [
        get_only_features_names(feature)
        for feature in preprocessor.get_feature_names_out()
    ]
    return pd.DataFrame(data=transformed, columns=column_names)
def build_graph(csv_file, *, threshold, corr_type):
    """Build an undirected graph linking features whose correlation exceeds a threshold.

    Args:
        csv_file: CSV input forwarded to ``clean_csv_file``.
        threshold: Only pairs with a correlation strictly greater than this
            value become edges.
        corr_type: Correlation method forwarded to ``DataFrame.corr``.

    Returns:
        A networkx graph in which every node carries ``node_identifier`` (its
        own name) and every edge carries ``corr_value`` (the correlation of
        the two features it joins).
    """
    features = clean_csv_file(csv_file)

    # Long-format table with one row per ordered feature pair. NaN
    # correlations are treated as 0.
    links = features.corr(method=corr_type).fillna(0).stack().reset_index()
    links.columns = ['var_1', 'var_2', 'corr_value']

    # Keep pairs above the threshold and drop the diagonal (self-correlation).
    keep = (links['corr_value'] > threshold) & (links['var_1'] != links['var_2'])

    # Carry the correlation on each row into the edge it creates. Pulling the
    # values from an unordered, de-duplicated set would attach weights to the
    # wrong edges and could run out of values when two pairs share one.
    G = nx.from_pandas_edgelist(
        links.loc[keep], 'var_1', 'var_2', edge_attr='corr_value'
    )

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node

    return G
class MplColorHelper:
    """Translate scalar values into colours from a named matplotlib colormap.

    Values are normalised over ``[start_val, stop_val]`` before the colormap
    lookup.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        """Set up the colormap, the normaliser, and the scalar-to-colour mapper."""
        colormap = plt.get_cmap(cmap_name)
        normalizer = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)

        self.cmap_name = cmap_name
        self.cmap = colormap
        self.norm = normalizer
        self.scalarMap = cm.ScalarMappable(norm=normalizer, cmap=colormap)

    def get_rgba(self, val):
        """Return the colour for ``val`` as produced by ``to_rgba(..., bytes=True)``."""
        mapper = self.scalarMap
        return mapper.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the colour for ``val`` as an ``"rgb(r,g,b)"`` string; alpha is dropped."""
        red, green, blue, _alpha = self.get_rgba(val)
        return "rgb({},{},{})".format(red, green, blue)
def display_graph(csv_file, *, threshold, corr_type):
    """Render the feature-correlation graph of a CSV file as an HTML string.

    Args:
        csv_file: CSV input forwarded to ``build_graph``.
        threshold: Correlation threshold forwarded to ``build_graph``.
        corr_type: Correlation method forwarded to ``build_graph``.

    Returns:
        The HTML produced by gravis for the interactive d3 figure.
    """
    G = build_graph(csv_file, threshold=threshold, corr_type=corr_type)

    # The figure sizes nodes by 'betweenness_centrality', but nothing in this
    # file ever stored that attribute on the nodes, so compute it here.
    nx.set_node_attributes(
        G, nx.betweenness_centrality(G), 'betweenness_centrality'
    )

    CM_NAME = "Wistia"
    corr_values = nx.get_edge_attributes(G, 'corr_value')

    # With no edge above the threshold there is nothing to colour, and
    # min()/max() on an empty collection would raise ValueError.
    if corr_values:
        edge_colors = MplColorHelper(
            CM_NAME, min(corr_values.values()), max(corr_values.values())
        )
        # Attach an rgb string to each edge.
        for _, _, data in G.edges(data=True):
            data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    disp = gv.d3(
        G,
        # graph specs
        graph_height=500,
        # node specs
        node_size_data_source="betweenness_centrality",
        # Centrality lies in [0, 1]; normalise so it maps to visible sizes.
        use_node_size_normalization=True,
        node_size_normalization_min=10,
        node_size_normalization_max=40,
        show_node_label=True,
        node_label_data_source='node_identifier',
        # edge specs
        edge_size_data_source='corr_value',
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,
        # force-directed graph specs
        edge_curvature=0.4,
        use_centering_force=True,
        many_body_force_strength=-300,
    )
    return disp.to_html()