# AutoML/src/ui/visualization.py
# akash
# all files
# 890025a
import re
import streamlit as st
import plotly.express as px
import plotly.graph_objects as go
from plotly.subplots import make_subplots
import pandas as pd
import numpy as np
from src.utils.logging import log_frontend_error, log_frontend_warning
# Row cap used by get_subsampled_data: frames longer than this are subsampled
# before plotting, and a random sample never returns more than this many rows.
SAMPLE_SIZE = 10000
# Cheap dataframe fingerprint used to detect when the data has changed
@st.cache_data(show_spinner=False)
def compute_df_hash(df):
    """Return an integer fingerprint of ``df`` for change detection.

    The fingerprint combines the frame's shape, its column labels and dtypes,
    and a content hash of at most the first 100 rows, so it stays cheap on
    large frames. Edits confined to rows beyond the first 100 that leave the
    shape and schema untouched are not detected.

    Args:
        df: The pandas DataFrame to fingerprint.

    Returns:
        int: A value that is equal for frames sharing shape, schema and
        leading rows.
    """
    head = df.iloc[:100]
    try:
        content_digest = int(pd.util.hash_pandas_object(head).sum())
    except TypeError:
        # Cells holding unhashable objects (lists, dicts, ...) cannot be
        # hashed directly; hash their string form instead of failing.
        content_digest = int(pd.util.hash_pandas_object(head.astype(str)).sum())

    # hash_pandas_object hashes row values and the index only, so a renamed
    # column or a changed dtype would otherwise yield the same fingerprint.
    schema = pd.Series(
        [str(dtype) for dtype in df.dtypes],
        index=[str(label) for label in df.columns],
        dtype=object,
    )
    schema_digest = int(pd.util.hash_pandas_object(schema).sum())

    return hash((df.shape, schema_digest, content_digest))
# Textual date layouts recognised at the start of a value; compiled once
_DATE_VALUE_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        r'\d{4}-\d{2}-\d{2}',     # YYYY-MM-DD
        r'\d{2}/\d{2}/\d{4}',     # MM/DD/YYYY
        r'\d{2}-\w{3}-\d{2,4}',   # DD-MON-YY(Y)
        r'\d{1,2} \w{3,} \d{4}',  # 1 January 2023
    )
)


@st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
def is_potential_date_column(series, sample_size=5):
    """Heuristically decide whether ``series`` might contain dates.

    The column is flagged when its label contains a date-related keyword, or
    when more than half of the first ``sample_size`` non-null values start
    with one of the known date layouts.

    Args:
        series: The pandas Series to inspect.
        sample_size: Number of leading non-null values to test.

    Returns:
        bool: True if the column looks date-like, False otherwise (including
        when the column has no non-null values).
    """
    # Labels may be None, ints or tuples; str() keeps the keyword check from
    # raising AttributeError on them.
    label = str(series.name).lower()
    if any(keyword in label for keyword in ('date', 'time', 'year', 'month', 'day')):
        return True

    sample = series.dropna().head(sample_size).astype(str)
    if sample.empty:
        return False

    matches = sum(
        1
        for value in sample
        if any(pattern.match(value) for pattern in _DATE_VALUE_PATTERNS)
    )
    return matches / len(sample) > 0.5  # >50% match
# Column type detection, cached by Streamlit
@st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
def get_column_types(df):
    """Classify every column of ``df`` into a coarse visualization type.

    Numeric columns become "BINARY" (at most 2 distinct values),
    "NUMERIC_DISCRETE" (fewer than 20) or "NUMERIC_CONTINUOUS". Other columns
    become "TEMPORAL" (date-like and at least one value parses as a
    datetime), "ID" (more than 90% distinct values and an id-like label),
    "CATEGORICAL" (at most 20 distinct values) or "TEXT".

    Args:
        df: The pandas DataFrame to analyse.

    Returns:
        dict: Column label -> type name, in column order.
    """
    column_types = {}
    row_count = len(df)

    for column in df.columns:
        series = df[column]

        if pd.api.types.is_numeric_dtype(series):
            unique_count = series.nunique()
            if unique_count <= 2:
                # Binary column (0/1, True/False)
                column_types[column] = "BINARY"
            elif unique_count < 20:
                column_types[column] = "NUMERIC_DISCRETE"
            else:
                column_types[column] = "NUMERIC_CONTINUOUS"
            continue

        if is_potential_date_column(series):
            try:
                converted = pd.to_datetime(series, errors='coerce')
                if not converted.isnull().all():  # At least some valid dates
                    column_types[column] = "TEMPORAL"
                    continue
            except Exception:
                # Best effort: a column that cannot be parsed at all simply
                # falls through to the non-temporal checks below.
                pass

        # Counted once and reused by both cardinality checks below.
        unique_count = series.nunique()
        # Labels may be ints or tuples; str() avoids an AttributeError.
        label = str(column).lower()
        looks_like_id = any(
            token in label for token in ('id', 'code', 'key', 'uuid', 'identifier')
        )

        if unique_count > row_count * 0.9 and looks_like_id:
            column_types[column] = "ID"
        elif unique_count <= 20:
            column_types[column] = "CATEGORICAL"
        else:
            column_types[column] = "TEXT"

    return column_types
# Correlation matrix over the numeric columns, cached by Streamlit
@st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
def get_corr_matrix(df):
    """Return the pairwise correlation of the numeric columns of ``df``.

    At most the first 30 numeric columns are used, to bound the cost on wide
    frames.

    Returns:
        The correlation DataFrame, or None when fewer than two numeric
        columns are available.
    """
    max_columns = 30
    # Slicing is a no-op when there are 30 or fewer numeric columns.
    selected = df.select_dtypes(include=[np.number]).columns[:max_columns]

    if len(selected) < 2:
        return None
    return df[selected].corr()
# Per-column subsample used for plotting, cached by Streamlit
@st.cache_data(show_spinner=False, ttl=3600)  # Cache for 1 hour
def get_subsampled_data(df, column):
    """Return a single-column subsample of ``df`` for faster visualization.

    Frames longer than SAMPLE_SIZE whose column has fewer than 20 distinct
    values are sampled per value (each group contributes the same fraction of
    its rows, at least one row), so rare values stay represented. Rows whose
    value is missing are not part of any group and are left out of that
    sample. In every other case, or if the per-value sampling fails or yields
    nothing, a plain random sample of at most SAMPLE_SIZE rows is returned.

    Args:
        df: The source DataFrame.
        column: Label of the column to sample.

    Returns:
        A DataFrame holding only ``column``; an empty DataFrame when the
        column does not exist.
    """
    if column not in df.columns:
        return pd.DataFrame()

    column_frame = df[[column]]
    total_rows = len(df)

    # Cheap length test first so nunique() only runs on large frames.
    if total_rows > SAMPLE_SIZE and df[column].nunique() < 20:
        try:
            fraction = min(0.5, SAMPLE_SIZE / total_rows)
            # Iterate the groups explicitly rather than groupby().apply():
            # apply() on the grouping column is deprecated in pandas 2.2 and
            # the column is lost once grouping columns are excluded.
            pieces = [
                group.sample(max(1, int(fraction * len(group))), random_state=42)
                for _, group in column_frame.groupby(column)
                if len(group) > 0
            ]
            if pieces:
                return pd.concat(pieces)
        except Exception:
            # Best effort: fall back to random sampling below.
            pass

    return column_frame.sample(min(total_rows, SAMPLE_SIZE), random_state=42)
# Chart builders: one private helper per column type, dispatched by create_chart
def _bar_and_pie_chart(counts, subplot_titles, column_widths, bar_color, pie_colors, title):
    """Build a 1x2 figure: value-count bars on the left, a share pie on the right."""
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=subplot_titles,
        specs=[[{"type": "bar"}, {"type": "pie"}]],
        column_widths=column_widths,
        horizontal_spacing=0.1,
    )
    fig.add_trace(go.Bar(
        x=counts.index,
        y=counts.values,
        marker_color=bar_color,
        text=counts.values,
        textposition='auto'
    ), row=1, col=1)
    fig.add_trace(go.Pie(
        labels=counts.index,
        values=counts.values,
        marker=dict(colors=pie_colors),
        textinfo='percent+label'
    ), row=1, col=2)
    fig.update_layout(title_text=title)
    return fig


def _build_year_chart(data, column):
    """Build per-year counts next to a box plot for a year-named column."""
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("Year Distribution", "Box Plot"),
        specs=[[{"type": "bar"}, {"type": "box"}]],
        column_widths=[0.7, 0.3],
        horizontal_spacing=0.1,
    )
    year_counts = data[column].value_counts().sort_index()
    fig.add_trace(
        go.Bar(x=year_counts.index, y=year_counts.values, marker_color='#7B68EE'),
        row=1, col=1,
    )
    fig.add_trace(go.Box(x=data[column], marker_color='#7B68EE'), row=1, col=2)
    return fig


def _build_binary_chart(data, column):
    """Build the bar + pie view for a two-valued column."""
    return _bar_and_pie_chart(
        data[column].value_counts(),
        subplot_titles=("Distribution", "Percentage"),
        column_widths=[0.5, 0.5],
        bar_color=['#FF4B4B', '#4CAF50'],
        pie_colors=['#FF4B4B', '#4CAF50'],
        title=f"Binary Distribution: {column}",
    )


def _build_continuous_chart(data, column):
    """Build histogram, box, violin and empirical CDF for a continuous column."""
    values = data[column]
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=("Distribution", "Box Plot", "Violin Plot", "Cumulative Distribution"),
        specs=[[{"type": "histogram"}, {"type": "box"}],
               [{"type": "violin"}, {"type": "scatter"}]],
        vertical_spacing=0.15,
        horizontal_spacing=0.1,
    )
    fig.add_trace(go.Histogram(
        x=values,
        nbinsx=30,
        marker_color='#FF4B4B',
        opacity=0.7
    ), row=1, col=1)
    fig.add_trace(go.Box(
        x=values,
        marker_color='#FF4B4B',
        boxpoints='outliers'
    ), row=1, col=2)
    fig.add_trace(go.Violin(
        x=values,
        marker_color='#FF4B4B',
        box_visible=True,
        points='outliers'
    ), row=2, col=1)
    # Empirical CDF: the i-th smallest value is plotted at height i / n.
    sorted_data = np.sort(values.dropna())
    cumulative = np.arange(1, len(sorted_data) + 1) / len(sorted_data)
    fig.add_trace(go.Scatter(
        x=sorted_data,
        y=cumulative,
        mode='lines',
        line=dict(color='#FF4B4B', width=2)
    ), row=2, col=2)
    fig.update_layout(height=600, title_text=f"Continuous Variable Analysis: {column}")
    return fig


def _build_discrete_chart(data, column):
    """Build the bar + pie view for a low-cardinality numeric column."""
    return _bar_and_pie_chart(
        data[column].value_counts().sort_index(),
        subplot_titles=("Distribution", "Percentage"),
        column_widths=[0.7, 0.3],
        bar_color='#FF4B4B',
        pie_colors=px.colors.sequential.Reds,
        title=f"Discrete Numeric Distribution: {column}",
    )


def _build_categorical_chart(data, column):
    """Build the bar + pie view for the 20 most frequent categories."""
    return _bar_and_pie_chart(
        data[column].value_counts().head(20),  # Limit to top 20 categories
        subplot_titles=("Category Distribution", "Percentage Breakdown"),
        column_widths=[0.6, 0.4],
        bar_color='#00FFA3',
        pie_colors=px.colors.sequential.Greens,
        title=f"Categorical Analysis: {column}",
    )


def _build_temporal_chart(data, column):
    """Build month, year, cumulative and day-of-week views for a date column."""
    # Unparseable values become NaT and are dropped.
    dates = pd.to_datetime(data[column], errors='coerce', format='mixed')
    valid_dates = dates[dates.notna()]
    fig = make_subplots(
        rows=2,
        cols=2,
        subplot_titles=("Monthly Pattern", "Yearly Pattern", "Cumulative Trend", "Day of Week Distribution"),
        vertical_spacing=0.15,
        horizontal_spacing=0.1,
        specs=[[{"type": "bar"}, {"type": "bar"}],
               [{"type": "scatter"}, {"type": "bar"}]]
    )
    if not valid_dates.empty:
        # Monthly pattern
        monthly_counts = valid_dates.dt.month.value_counts().sort_index()
        month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun', 'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
        month_labels = [month_names[i - 1] for i in monthly_counts.index]
        fig.add_trace(go.Bar(
            x=month_labels,
            y=monthly_counts.values,
            marker_color='#7B68EE',
            text=monthly_counts.values,
            textposition='auto'
        ), row=1, col=1)
        # Yearly pattern
        yearly_counts = valid_dates.dt.year.value_counts().sort_index()
        fig.add_trace(go.Bar(
            x=yearly_counts.index,
            y=yearly_counts.values,
            marker_color='#7B68EE',
            text=yearly_counts.values,
            textposition='auto'
        ), row=1, col=2)
        # Cumulative trend: running count of records over time
        sorted_dates = valid_dates.sort_values()
        cumulative = np.arange(1, len(sorted_dates) + 1)
        fig.add_trace(go.Scatter(
            x=sorted_dates,
            y=cumulative,
            mode='lines',
            line=dict(color='#7B68EE', width=2)
        ), row=2, col=1)
        # Day of week distribution (dayofweek: 0 = Monday)
        dow_counts = valid_dates.dt.dayofweek.value_counts().sort_index()
        dow_names = ['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun']
        dow_labels = [dow_names[i] for i in dow_counts.index]
        fig.add_trace(go.Bar(
            x=dow_labels,
            y=dow_counts.values,
            marker_color='#7B68EE',
            text=dow_counts.values,
            textposition='auto'
        ), row=2, col=2)
    fig.update_layout(height=600, title_text=f"Temporal Analysis: {column}")
    return fig


def _build_id_chart(data, column):
    """Build ID length histogram and the 15 most common 2-character prefixes."""
    as_text = data[column].astype(str)
    id_lengths = as_text.str.len()
    id_prefixes = as_text.str[:2].value_counts().head(15)
    fig = make_subplots(
        rows=1,
        cols=2,
        subplot_titles=("ID Length Distribution", "Common ID Prefixes"),
        horizontal_spacing=0.1,
        specs=[[{"type": "histogram"}, {"type": "bar"}]]
    )
    fig.add_trace(go.Histogram(
        x=id_lengths,
        nbinsx=20,
        marker_color='#9C27B0'
    ), row=1, col=1)
    fig.add_trace(go.Bar(
        x=id_prefixes.index,
        y=id_prefixes.values,
        marker_color='#9C27B0',
        text=id_prefixes.values,
        textposition='auto'
    ), row=1, col=2)
    fig.update_layout(title_text=f"ID Analysis: {column}")
    return fig


def _build_text_chart(data, column):
    """Build the 15 most frequent values and a text-length histogram."""
    value_counts = data[column].value_counts().head(15)
    text_lengths = data[column].astype(str).str.len()
    fig = make_subplots(
        rows=2,
        cols=1,
        subplot_titles=("Top Values", "Text Length Distribution"),
        vertical_spacing=0.2,
        specs=[[{"type": "bar"}], [{"type": "histogram"}]]
    )
    fig.add_trace(
        go.Bar(
            x=value_counts.index,
            y=value_counts.values,
            marker_color='#00B4D8',
            text=value_counts.values,
            textposition='auto'
        ),
        row=1, col=1
    )
    fig.add_trace(
        go.Histogram(
            x=text_lengths,
            nbinsx=20,
            marker_color='#00B4D8'
        ),
        row=2, col=1
    )
    fig.update_layout(
        height=600,
        title_text=f"Text Analysis: {column}"
    )
    return fig


def _build_generic_chart(data, column):
    """Build a plain histogram for any unrecognised column type."""
    fig = go.Figure(go.Histogram(x=data[column], marker_color='#888'))
    fig.update_layout(title_text=f"Generic Analysis: {column}")
    return fig


# Column type name -> builder taking (sampled single-column frame, column label)
_CHART_BUILDERS = {
    "BINARY": _build_binary_chart,
    "NUMERIC_CONTINUOUS": _build_continuous_chart,
    "NUMERIC_DISCRETE": _build_discrete_chart,
    "CATEGORICAL": _build_categorical_chart,
    "TEMPORAL": _build_temporal_chart,
    "ID": _build_id_chart,
    "TEXT": _build_text_chart,
}


# Chart creation, cached by Streamlit
@st.cache_data(show_spinner=False, ttl=1800, hash_funcs={  # Cache for 30 minutes
    pd.DataFrame: compute_df_hash,
    pd.Series: lambda s: hash((s.name, compute_df_hash(s.to_frame())))
})
def create_chart(df, column, column_type):
    """Generate a Plotly figure for one column, chosen by its detected type.

    Columns whose label contains "year" always get the year view; otherwise
    ``column_type`` selects the builder, and unknown types get a plain
    histogram. The figure is built from a subsample of the column.

    Args:
        df: The source DataFrame.
        column: Label of the column to plot.
        column_type: Type name as produced by get_column_types.

    Returns:
        The Plotly figure, or None when the column is missing, the sample is
        empty, or building the chart raised (the error is logged).
    """
    if column not in df.columns:
        return None

    df_sample = get_subsampled_data(df, column)
    if df_sample.empty:
        return None

    try:
        # str() so that non-string labels (ints, tuples) do not raise here.
        if "year" in str(column).lower():
            builder = _build_year_chart
        else:
            builder = _CHART_BUILDERS.get(column_type, _build_generic_chart)
        fig = builder(df_sample, column)

        # Common layout settings. Keep a height a builder chose on purpose
        # (600 for the two-row layouts) instead of overwriting it with the
        # 400 default, which squeezed those charts.
        fig.update_layout(
            height=fig.layout.height or 400,
            showlegend=False,
            plot_bgcolor='rgba(0,0,0,0)',
            paper_bgcolor='rgba(0,0,0,0)',
            font=dict(color='#FFFFFF'),
            margin=dict(l=40, r=40, t=50, b=40)
        )
        return fig
    except Exception as e:
        # UI boundary: a failing chart must not take down the whole page.
        log_frontend_error("Chart Generation", f"Error creating chart for {column}: {str(e)}")
        return None
def visualize_data(df):
    """Render the automated visualization dashboard for ``df``.

    Shows a column multiselect, then for each selected column a chart and a
    collapsible summary (describe() for numeric types, value counts
    otherwise), two columns per row, followed by a correlation heatmap when
    one is available. Column types and the correlation matrix are kept in
    ``st.session_state`` and recomputed only when the dataframe fingerprint
    changes; on such a change previously stored test results are reset.

    Args:
        df: The DataFrame to visualize. None or an empty frame shows an
            error message and returns.
    """
    if df is None or df.empty:
        st.error("❌ No data available. Please upload and clean your data first.")
        return

    # Calculate dataframe hash only once
    df_hash = compute_df_hash(df)

    # Start with the first (up to) 4 columns selected
    if "selected_viz_columns" not in st.session_state:
        st.session_state.selected_viz_columns = list(df.columns[:4])

    # Drop remembered columns that no longer exist in the dataframe
    valid_columns = [col for col in st.session_state.selected_viz_columns if col in df.columns]

    def on_column_selection_change():
        # Remember the selection and stay on the visualization tab (index 2)
        st.session_state.selected_viz_columns = st.session_state.viz_column_selector
        st.session_state.current_tab_index = 2

    selected_columns = st.multiselect(
        "Select columns to visualize",
        options=df.columns,
        default=valid_columns,
        key="viz_column_selector",
        on_change=on_column_selection_change
    )

    # Recompute when nothing is cached yet or the data changed. A missing
    # "df_hash" is covered by .get() returning None.
    recompute_needed = (
        "column_types" not in st.session_state
        or st.session_state.get("df_hash") != df_hash
    )
    if recompute_needed:
        with st.spinner("🔄 Analyzing data structure..."):
            st.session_state.column_types = get_column_types(df)
            st.session_state.corr_matrix = get_corr_matrix(df)
            st.session_state.df_hash = df_hash
            # Ensure we stay on the visualization tab
            st.session_state.current_tab_index = 2
            # Reset any test results if the data has changed
            if "test_results_calculated" in st.session_state:
                st.session_state.test_results_calculated = False
                # Clear previous test metrics to avoid using stale data
                for key in ['test_metrics', 'test_y_pred', 'test_y_test', 'test_cm', 'sampling_message']:
                    if key in st.session_state:
                        del st.session_state[key]

    # Use cached values from session state
    column_types = st.session_state.column_types
    corr_matrix = st.session_state.corr_matrix

    if selected_columns:
        # Use a container to wrap all visualizations
        viz_container = st.container()
        with viz_container:
            for idx in range(0, len(selected_columns), 2):
                col1, col2 = st.columns(2)
                for slot, column in zip((col1, col2), selected_columns[idx:idx + 2]):
                    with slot:
                        # Stable widget keys derived from the column label;
                        # str() so non-string labels do not raise.
                        safe_name = str(column).replace(' ', '_')
                        chart_key = f"plot_{safe_name}"
                        # Only create chart if column exists in column_types
                        if column in column_types:
                            fig = create_chart(df, column, column_types[column])
                            if fig is not None:
                                st.plotly_chart(fig, use_container_width=True, key=chart_key)
                                with st.expander(f"📊 Summary Statistics - {column}", expanded=False):
                                    if "NUMERIC" in column_types[column]:
                                        st.dataframe(df[column].describe(), key=f"stats_{safe_name}")
                                    else:
                                        st.dataframe(df[column].value_counts(), key=f"counts_{safe_name}")
                        else:
                            st.warning(f"⚠️ Column '{column}' not found in the dataset or its type couldn't be determined.")
            if corr_matrix is not None:
                st.subheader("🔗 Correlation Analysis")
                fig = px.imshow(corr_matrix, title="Correlation Matrix", color_continuous_scale="RdBu")
                st.plotly_chart(fig, use_container_width=True, key="corr_matrix_plot")
    else:
        st.info("👆 Please select columns to visualize")