correlationGraph / utils.py
thov's picture
add cat features
1c6ad85
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import cm
import gravis as gv
import networkx as nx
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
def get_only_features_names(name):
    """Strip the transformer prefix from a ColumnTransformer feature name.

    ``ColumnTransformer.get_feature_names_out()`` yields names shaped like
    ``"<transformer>__<feature>"`` (e.g. ``"num__age"``). The prefix is
    removed by splitting on the first ``"__"`` rather than slicing a fixed
    number of characters, so transformer names of any length are handled.

    Args:
        name: Feature name as produced by ``get_feature_names_out()``.

    Returns:
        The part of ``name`` after the first ``"__"``, or ``name`` unchanged
        when it carries no ``"__"`` separator.
    """
    _prefix, separator, feature = name.partition("__")
    # Without a separator there is no prefix to strip; keep the name intact
    # instead of chopping off its first characters.
    return feature if separator else name
def clean_csv_file(csv_file):
    """Load a CSV file and return an imputed, fully numeric feature table.

    ``float64``/``int64`` columns are mean-imputed and standardised.
    ``object`` columns are imputed with their most frequent value and
    one-hot encoded. Columns of any other dtype are not passed to either
    transformer and therefore do not appear in the result.

    Args:
        csv_file: Path or file-like object accepted by ``pandas.read_csv``.

    Returns:
        A ``pandas.DataFrame`` with one column per transformed feature,
        named after the source column (plus the category for one-hot
        columns) with the transformer prefix removed.
    """
    df = pd.read_csv(csv_file)

    num_cols = df.select_dtypes(include=['float64', 'int64']).columns
    cat_cols = df.select_dtypes(include=['object']).columns

    numeric_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='mean')),
        ('scaler', StandardScaler())
    ])
    categorical_transformer = Pipeline(steps=[
        ('imputer', SimpleImputer(strategy='most_frequent')),
        ('onehot', OneHotEncoder(handle_unknown='ignore'))
    ])

    preprocessor = ColumnTransformer(
        transformers=[
            ('num', numeric_transformer, num_cols),
            ('cat', categorical_transformer, cat_cols)
        ],
        # The one-hot encoder emits a sparse matrix; with the default
        # threshold the stacked result can stay sparse, which the DataFrame
        # constructor below cannot consume. 0 forces a dense array.
        sparse_threshold=0,
    )

    transformed = preprocessor.fit_transform(df)
    feature_names = [
        get_only_features_names(full_name)
        for full_name in preprocessor.get_feature_names_out()
    ]
    return pd.DataFrame(data=transformed, columns=feature_names)
def build_graph(csv_file, *, threshold, corr_type):
    """Build an undirected graph linking strongly correlated features.

    The CSV is preprocessed with ``clean_csv_file``, the pairwise correlation
    matrix is computed, and every pair of distinct features whose correlation
    is strictly greater than ``threshold`` becomes an edge. Features with no
    such pair do not appear in the graph.

    Args:
        csv_file: CSV source forwarded to ``clean_csv_file``.
        threshold: Lower bound (exclusive) a correlation must exceed for the
            pair to be connected. The comparison is on the signed value, so
            negative correlations only pass a negative threshold.
        corr_type: Correlation method forwarded to ``DataFrame.corr``.

    Returns:
        A ``networkx.Graph`` whose nodes carry a ``node_identifier`` attribute
        (the feature name) and whose edges carry a ``corr_value`` attribute
        (the correlation between the two endpoints).
    """
    features = clean_csv_file(csv_file)

    # Long format: one row per (feature, feature) pair. Undefined
    # correlations (NaN) are treated as 0.
    corr_matrix = features.corr(method=corr_type)
    links = corr_matrix.fillna(0).stack().reset_index()
    links.columns = ['var_1', 'var_2', 'corr_value']

    is_strong = links['corr_value'] > threshold
    is_distinct_pair = links['var_1'] != links['var_2']
    links_filtered = links.loc[is_strong & is_distinct_pair]

    # Carry the correlation along with its own row so that every edge gets
    # the value of its actual endpoints (assigning from a separate, unordered
    # collection of values would mismatch edges and values).
    G = nx.from_pandas_edgelist(
        links_filtered, 'var_1', 'var_2', edge_attr='corr_value'
    )

    for node, data in G.nodes(data=True):
        data['node_identifier'] = node

    return G
class MplColorHelper:
    """Translate scalar values into colours of a named matplotlib colormap."""

    def __init__(self, cmap_name, start_val, stop_val):
        """Prepare the colormap and the value range it is stretched over.

        Args:
            cmap_name: Name of the matplotlib colormap to look up.
            start_val: Value mapped to the low end of the colormap.
            stop_val: Value mapped to the high end of the colormap.
        """
        colormap = plt.get_cmap(cmap_name)
        value_range = mpl.colors.Normalize(vmin=start_val, vmax=stop_val)

        self.cmap_name = cmap_name
        self.cmap = colormap
        self.norm = value_range
        self.scalarMap = cm.ScalarMappable(norm=value_range, cmap=colormap)

    def get_rgba(self, val):
        """Return the RGBA colour for ``val``, requested in byte form."""
        return self.scalarMap.to_rgba(val, bytes=True)

    def get_rgb_str(self, val):
        """Return the colour for ``val`` as an ``rgb(r,g,b)`` string."""
        # The alpha channel is not part of the rgb() string.
        red, green, blue, _alpha = self.get_rgba(val)
        return f"rgb({red},{green},{blue})"
def display_graph(csv_file, *, threshold, corr_type):
    """Render the feature-correlation graph of a CSV file as HTML.

    The graph comes from ``build_graph``. Edges are coloured along the
    "Wistia" colormap and sized by their ``corr_value``; nodes are labelled
    with their ``node_identifier`` and sized by betweenness centrality.

    Args:
        csv_file: CSV source forwarded to ``build_graph``.
        threshold: Correlation threshold forwarded to ``build_graph``.
        corr_type: Correlation method forwarded to ``build_graph``.

    Returns:
        The HTML produced by ``to_html()`` on the gravis d3 figure.
    """
    G = build_graph(csv_file, threshold=threshold, corr_type=corr_type)

    CM_NAME = "Wistia"
    corr_by_edge = nx.get_edge_attributes(G, 'corr_value')
    # When no pair exceeds the threshold the graph has no edges; min()/max()
    # on an empty collection would raise, so colouring is skipped.
    if corr_by_edge:
        vals = corr_by_edge.values()
        edge_colors = MplColorHelper(CM_NAME, min(vals), max(vals))
        # get rgb string for each edge
        for _, _, data in G.edges(data=True):
            data['color'] = edge_colors.get_rgb_str(data['corr_value'])

    # The figure sizes nodes from this attribute, so it has to be stored on
    # the nodes before rendering.
    nx.set_node_attributes(
        G, nx.betweenness_centrality(G), 'betweenness_centrality'
    )

    disp = gv.d3(
        G,

        # graph specs
        graph_height=500,

        # node specs
        node_size_data_source="betweenness_centrality",
        # Centrality lies in [0, 1]; rescale it to visible node sizes.
        use_node_size_normalization=True,
        node_size_normalization_min=10,
        node_size_normalization_max=40,
        show_node_label=True,
        node_label_data_source='node_identifier',

        # edge specs
        edge_size_data_source='corr_value',
        use_edge_size_normalization=True,
        edge_size_normalization_min=0.3,
        edge_size_normalization_max=6,

        # force-directed graph specs
        edge_curvature=0.4,
        use_centering_force=True,
        many_body_force_strength=-300,
    )

    return disp.to_html()