# Correlation-graph visualization helpers (Hugging Face Space app).
# NOTE(review): the lines previously here were HTML-scrape residue from the
# Space's file-viewer page (status text, commit hashes, line-number gutter),
# not Python — replaced with this comment so the module parses.
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
def get_only_features_names(name):
    """Strip the ColumnTransformer prefix from an output feature name.

    ``ColumnTransformer.get_feature_names_out()`` prefixes every output
    column with ``'<transformer_name>__'`` (e.g. ``'num__age'`` -> ``'age'``).
    The original hard-coded ``name[5:]``, which only works when the
    transformer name is exactly 3 characters; splitting on the first
    ``'__'`` recovers the column name for any transformer name, and
    returns the input unchanged when no prefix is present.
    """
    return name.split('__', 1)[-1]
def clean_csv_file(csv_file):
    """Load a CSV and return a fully numeric, preprocessed DataFrame.

    Numeric (float64/int64) columns are mean-imputed and standard-scaled;
    object columns are mode-imputed and one-hot encoded (categories unseen
    at fit time are ignored at transform time). Output columns keep the
    original names, with the ColumnTransformer prefixes stripped.

    Parameters
    ----------
    csv_file : str | path-like | file-like
        Anything accepted by ``pd.read_csv``.

    Returns
    -------
    pd.DataFrame
        Dense, all-numeric transformed data.
    """
    df = pd.read_csv(csv_file)
    num_cols = df.select_dtypes(include=['float64', 'int64']).columns
    cat_cols = df.select_dtypes(include=['object']).columns
    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler())
    ])
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])
    # sparse_threshold=0 forces a dense ndarray. Without it, a mostly-sparse
    # one-hot result is stacked as a scipy sparse matrix, and the
    # pd.DataFrame(data=...) construction below fails on it.
    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, num_cols),
            ('cat', categorical_transformer, cat_cols)
        ],
        sparse_threshold=0)
    # The single-step Pipeline wrapper around the preprocessor was redundant;
    # fit_transform directly on the ColumnTransformer is equivalent.
    transformed = preprocessor.fit_transform(df)
    feature_names = map(get_only_features_names, preprocessor.get_feature_names_out())
    return pd.DataFrame(data=transformed, columns=list(feature_names))
def build_graph(csv_file, *, threshold, corr_type):
    """Build a correlation graph from the columns of a CSV file.

    Nodes are feature names; an undirected edge links two features whose
    pairwise correlation exceeds ``threshold``. Each node carries a
    ``node_identifier`` attribute (its own name) and each edge a
    ``corr_value`` attribute (the correlation of its endpoints).

    Parameters
    ----------
    csv_file : str | path-like | file-like
        Passed through to ``clean_csv_file``.
    threshold : float
        Minimum correlation (exclusive) for an edge to be kept.
    corr_type : str
        Correlation method for ``DataFrame.corr`` ('pearson', 'kendall',
        'spearman').

    Returns
    -------
    nx.Graph
    """
    features = clean_csv_file(csv_file)
    links = features.corr(method=corr_type)
    # Long format: one row per (var_1, var_2, corr_val); NaNs (e.g. constant
    # columns) are treated as zero correlation.
    links = links.fillna(0).stack().reset_index()
    links.columns = ['var_1', 'var_2', 'corr_val']
    links_filtered = links.loc[(links['corr_val'] > threshold) & (links['var_1'] != links['var_2'])]
    # BUG FIX: the original drew values from iter(set(corr_val)) while looping
    # over edges — sets deduplicate and are unordered, so edges received
    # arbitrary correlations (and StopIteration if unique values < edges).
    # edge_attr attaches the correct value to each edge directly.
    G = nx.from_pandas_edgelist(links_filtered, 'var_1', 'var_2', edge_attr='corr_val')
    for node, data in G.nodes(data=True):
        data['node_identifier'] = node
    # Rename the attribute to 'corr_value', the key the rest of the file uses.
    for _, _, data in G.edges(data=True):
        data['corr_value'] = data.pop('corr_val')
    return G
class MplColorHelper:
    """Map scalar values onto colors from a named matplotlib colormap.

    Values in ``[start_val, stop_val]`` are linearly normalized to [0, 1]
    and looked up in the colormap ``cmap_name``.
    """

    def __init__(self, cmap_name, start_val, stop_val):
        self.cmap_name = cmap_name
        self.cmap = plt.get_cmap(cmap_name)
        self.norm = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)
        self.scalarMap = cm.ScalarMappable(norm=self.norm, cmap=self.cmap)

    def get_rgba(self, val):
        """Return ``val``'s color as an (r, g, b, a) tuple of 0-255 ints."""
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return ``val``'s color as a CSS-style ``'rgb(r,g,b)'`` string."""
        rgba = self.get_rgba(val)
        return "rgb({0},{1},{2})".format(rgba[0], rgba[1], rgba[2])
def display_graph(csv_file, *, threshold, corr_type):
    """Render the correlation graph of a CSV file as an interactive HTML page.

    Builds the graph with ``build_graph``, colors each edge by its
    correlation value using the 'Wistia' colormap, and renders the result
    with gravis' d3 backend.

    Parameters
    ----------
    csv_file : str | path-like | file-like
        Passed through to ``build_graph``.
    threshold : float
        Minimum correlation for an edge to appear.
    corr_type : str
        Correlation method ('pearson', 'kendall', 'spearman').

    Returns
    -------
    str
        A self-contained HTML document.

    Raises
    ------
    ValueError
        If the graph has no edges (min/max over an empty attribute view).
    """
    G = build_graph(csv_file, threshold=threshold, corr_type=corr_type)
    CM_NAME = "Wistia"
    # Scale the colormap to the observed range of correlation values.
    vals = nx.get_edge_attributes(G, 'corr_value').values()
    val_min, val_max = min(vals), max(vals)
    edge_colors = MplColorHelper(CM_NAME, val_min, val_max)
    # Color each EDGE by its correlation (original comment said "node").
    for u, v, data in G.edges(data=True):
        data['color'] = edge_colors.get_rgb_str(data['corr_value'])
    # NOTE(review): node_size_data_source expects a 'betweenness_centrality'
    # node attribute; none is set in this file — presumably gravis falls back
    # to a default size when it is absent. TODO confirm.
    disp = gv.d3(
        G,
        # graph specs
        graph_height=500,
        # node specs
        node_size_data_source="betweenness_centrality",
        show_node_label=True,
        node_label_data_source='node_identifier',
        # edge specs
        edge_size_data_source='corr_value',
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,
        # force-directed graph specs
        edge_curvature=0.4,
        use_centering_force=True,
        many_body_force_strength=-300,
    )
    # BUG FIX: the original line ended with a stray ' |' (scrape artifact),
    # which made the file a syntax error.
    return disp.to_html()