# TimeFlowPro1/visualization/visualization_manager.py
# ============================================
# CLASS 13: VISUALISATION MANAGER (UPDATED)
# ============================================
import os
import json
import logging
import shutil
from datetime import datetime
from typing import Dict, List, Optional, Tuple

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use('Agg')  # Select the non-display backend BEFORE pyplot is imported
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import gaussian_kde
from statsmodels.graphics.tsaplots import plot_acf, plot_pacf

from config.config import Config
# Logging setup
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class VisualisationManager:
"""Class for managing all visualisations"""
def __init__(self, config: Config):
"""
Initialise visualisation manager
Parameters:
-----------
config : Config
Experiment configuration
"""
self.config = config
self.plots_generated = {}
self.plot_files = {}
self.figure_count = 0
# Create directory structure for saving plots
self._create_directory_structure()
def _create_directory_structure(self) -> None:
"""Create directory structure for saving plots"""
base_dir = self.config.results_dir
# Main plot directories
self.plots_dir = os.path.join(base_dir, "plots")
self.correlations_dir = os.path.join(base_dir, "plots", "correlations")
self.distributions_dir = os.path.join(base_dir, "plots", "distributions")
self.features_dir = os.path.join(base_dir, "plots", "features")
self.time_series_dir = os.path.join(base_dir, "plots", "time_series")
self.preprocessing_dir = os.path.join(base_dir, "plots", "preprocessing")
self.summary_dir = os.path.join(base_dir, "plots", "summary")
self.reports_dir = os.path.join(base_dir, "reports")
# Create directories
directories = [
self.plots_dir,
self.correlations_dir,
self.distributions_dir,
self.features_dir,
self.time_series_dir,
self.preprocessing_dir,
self.summary_dir,
self.reports_dir
]
for directory in directories:
os.makedirs(directory, exist_ok=True)
logger.debug(f"Created directory: {directory}")
    def _save_figure(self, fig: plt.Figure, filename: str,
                     subdirectory: Optional[str] = None, dpi: int = 300) -> Optional[str]:
        """
        Save a figure to disk and close it
        Parameters:
        -----------
        fig : matplotlib.figure.Figure
            Figure object to save
        filename : str
            Filename for saving
        subdirectory : str, optional
            Subdirectory for saving
        dpi : int
            Resolution of the saved image (dots per inch)
        Returns:
        --------
        str : full path to saved file, or None on error
        """
if not filename.endswith('.png'):
filename = f"{filename}.png"
if subdirectory:
save_dir = os.path.join(self.plots_dir, subdirectory)
os.makedirs(save_dir, exist_ok=True)
else:
save_dir = self.plots_dir
filepath = os.path.join(save_dir, filename)
try:
fig.savefig(filepath, dpi=dpi, bbox_inches='tight', facecolor='white')
            logger.info(f"✓ Plot saved: {filepath}")
except Exception as e:
            logger.error(f"✗ Error saving plot {filename}: {e}")
filepath = None
# Close plot without display
plt.close(fig)
return filepath
# ============================================
# MAIN VISUALISATION METHODS
# ============================================
def create_summary_dashboard(
self,
data: pd.DataFrame,
preprocessing_stages: Dict = None,
filename: str = "summary_dashboard"
    ) -> Optional[str]:
"""
Create summary visualisation dashboard
Parameters:
-----------
data : pd.DataFrame
Data for visualisation
preprocessing_stages : Dict, optional
Preprocessing stages information
filename : str
Filename for saving
Returns:
--------
str : path to saved file or None if error
"""
logger.info("\n" + "="*80)
logger.info("CREATING SUMMARY DASHBOARD")
logger.info("="*80)
target_col = self.config.target_column
try:
# Create large dashboard
fig = plt.figure(figsize=(20, 24))
gs = fig.add_gridspec(6, 4, hspace=0.3, wspace=0.3)
# 1. Time series of target variable
ax1 = fig.add_subplot(gs[0, :2])
if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
ax1.plot(data.index, data[target_col], linewidth=1, color='blue', alpha=0.7)
ax1.set_title(f'Time Series: {target_col}', fontsize=12, fontweight='bold')
ax1.set_xlabel('Date', fontsize=10)
ax1.set_ylabel(target_col, fontsize=10)
ax1.grid(True, alpha=0.3)
ax1.tick_params(axis='x', rotation=45)
else:
ax1.text(0.5, 0.5, 'No time series data available',
ha='center', va='center', transform=ax1.transAxes)
# 2. Target variable distribution
ax2 = fig.add_subplot(gs[0, 2:])
if target_col in data.columns:
values = data[target_col].dropna()
if len(values) > 0:
ax2.hist(values, bins=30, edgecolor='black', alpha=0.7, color='green')
ax2.set_title(f'Distribution: {target_col}', fontsize=12, fontweight='bold')
ax2.set_xlabel(target_col, fontsize=10)
ax2.set_ylabel('Frequency', fontsize=10)
ax2.grid(True, alpha=0.3)
else:
ax2.text(0.5, 0.5, 'No data for distribution',
ha='center', va='center', transform=ax2.transAxes)
# 3. Correlation matrix (top features)
ax3 = fig.add_subplot(gs[1, :])
numeric_cols = data.select_dtypes(include=[np.number]).columns
if len(numeric_cols) > 1:
display_cols = list(numeric_cols[:15])
if target_col not in display_cols and target_col in data.columns:
display_cols = [target_col] + [c for c in display_cols if c != target_col][:14]
corr_matrix = data[display_cols].corr()
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
im = ax3.imshow(corr_matrix.where(~mask), cmap='coolwarm', vmin=-1, vmax=1, aspect='auto')
ax3.set_title('Correlation Matrix (Top 15 Features)',
fontsize=12, fontweight='bold')
ax3.set_xticks(range(len(display_cols)))
ax3.set_yticks(range(len(display_cols)))
ax3.set_xticklabels(display_cols, rotation=90, fontsize=8)
ax3.set_yticklabels(display_cols, fontsize=8)
plt.colorbar(im, ax=ax3, shrink=0.8)
# 4. Seasonal patterns
ax4 = fig.add_subplot(gs[2, :2])
if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
data_copy = data.copy()
data_copy['month'] = data_copy.index.month
monthly_avg = data_copy.groupby('month')[target_col].mean()
colors = plt.cm.Set3(np.linspace(0, 1, len(monthly_avg)))
ax4.bar(monthly_avg.index, monthly_avg.values, color=colors, edgecolor='black')
ax4.set_title('Average Values by Month', fontsize=12, fontweight='bold')
ax4.set_xlabel('Month', fontsize=10)
ax4.set_ylabel(f'Average {target_col}', fontsize=10)
ax4.set_xticks(range(1, 13))
month_names = ['Jan', 'Feb', 'Mar', 'Apr', 'May', 'Jun',
'Jul', 'Aug', 'Sep', 'Oct', 'Nov', 'Dec']
ax4.set_xticklabels(month_names)
ax4.grid(True, alpha=0.3, axis='y')
# 5. Weekly patterns
ax5 = fig.add_subplot(gs[2, 2:])
if target_col in data.columns and isinstance(data.index, pd.DatetimeIndex):
data_copy = data.copy()
data_copy['dayofweek'] = data_copy.index.dayofweek
daily_avg = data_copy.groupby('dayofweek')[target_col].mean()
colors = plt.cm.Paired(np.linspace(0, 1, len(daily_avg)))
ax5.bar(daily_avg.index, daily_avg.values, color=colors, edgecolor='black')
ax5.set_title('Average Values by Day of Week', fontsize=12, fontweight='bold')
ax5.set_xlabel('Day of Week', fontsize=10)
ax5.set_ylabel(f'Average {target_col}', fontsize=10)
ax5.set_xticks(range(7))
ax5.set_xticklabels(['Mon', 'Tue', 'Wed', 'Thu', 'Fri', 'Sat', 'Sun'])
ax5.grid(True, alpha=0.3, axis='y')
# 6. Trend and seasonality
ax6 = fig.add_subplot(gs[3, :])
if target_col in data.columns and len(data) > 30:
try:
window_size = min(365, len(data) // 10)
if window_size >= 7:
rolling_mean = data[target_col].rolling(window=window_size, center=True).mean()
rolling_std = data[target_col].rolling(window=window_size, center=True).std()
ax6.plot(data.index, data[target_col], alpha=0.5,
label='Original Series', linewidth=0.5, color='blue')
ax6.plot(rolling_mean.index, rolling_mean,
label=f'Rolling Mean ({window_size} days)',
color='red', linewidth=2)
ax6.fill_between(rolling_mean.index,
rolling_mean - rolling_std,
rolling_mean + rolling_std,
alpha=0.2, color='red')
ax6.set_title('Trend and Volatility', fontsize=12, fontweight='bold')
ax6.set_xlabel('Date', fontsize=10)
ax6.set_ylabel(target_col, fontsize=10)
ax6.legend(fontsize=9, loc='upper left')
ax6.grid(True, alpha=0.3)
else:
ax6.text(0.5, 0.5, 'Insufficient data for trend analysis',
ha='center', va='center', transform=ax6.transAxes)
except Exception as e:
logger.warning(f"Error plotting trend: {e}")
ax6.text(0.5, 0.5, 'Error plotting trend',
ha='center', va='center', transform=ax6.transAxes)
# 7. Preprocessing statistics
if preprocessing_stages:
ax7 = fig.add_subplot(gs[4, :2])
stages = list(preprocessing_stages.keys())
values = list(preprocessing_stages.values())
colors = plt.cm.viridis(np.linspace(0.3, 0.9, len(stages)))
bars = ax7.bar(range(len(stages)), values, color=colors, edgecolor='black')
ax7.set_title('Preprocessing Statistics', fontsize=12, fontweight='bold')
ax7.set_xlabel('Processing Stage', fontsize=10)
ax7.set_ylabel('Value', fontsize=10)
ax7.set_xticks(range(len(stages)))
ax7.set_xticklabels([s[:15] + '...' if len(s) > 15 else s for s in stages],
rotation=45, ha='right', fontsize=9)
ax7.grid(True, alpha=0.3, axis='y')
# Add values on bars
for bar, value in zip(bars, values):
height = bar.get_height()
ax7.text(bar.get_x() + bar.get_width()/2., height,
f'{value:.2f}', ha='center', va='bottom', fontsize=8)
# 8. Data information
ax8 = fig.add_subplot(gs[4, 2:])
ax8.axis('off')
            info_text = []
            info_text.append("GENERAL CHARACTERISTICS:")
            info_text.append(f"• Number of records: {len(data):,}")
            info_text.append(f"• Number of features: {len(data.columns)}")
            if isinstance(data.index, pd.DatetimeIndex):
                info_text.append(f"• Period: {data.index.min().strftime('%Y-%m-%d')} - "
                                 f"{data.index.max().strftime('%Y-%m-%d')}")
                info_text.append(f"• Days of data: {(data.index.max() - data.index.min()).days}")
            if target_col in data.columns:
                target_stats = data[target_col].describe()
                info_text.append(f"\nTARGET VARIABLE '{target_col}':")
                info_text.append(f"• Mean: {target_stats['mean']:.2f}")
                info_text.append(f"• Standard deviation: {target_stats['std']:.2f}")
                info_text.append(f"• Minimum: {target_stats['min']:.2f}")
                info_text.append(f"• 25%: {target_stats['25%']:.2f}")
                info_text.append(f"• 50% (median): {target_stats['50%']:.2f}")
                info_text.append(f"• 75%: {target_stats['75%']:.2f}")
                info_text.append(f"• Maximum: {target_stats['max']:.2f}")
            info_text.append("\nDATA TYPES:")
            for dtype, count in data.dtypes.value_counts().items():
                info_text.append(f"• {dtype}: {count} columns")
            missing_info = data.isnull().sum()
            missing_total = missing_info.sum()
            missing_percent = missing_total / data.size * 100
            info_text.append("\nMISSING VALUES:")
            info_text.append(f"• Total missing: {missing_total:,}")
            info_text.append(f"• Missing percentage: {missing_percent:.2f}%")
            if missing_total > 0:
                top_missing = missing_info.nlargest(5)
                info_text.append("• Top 5 columns with missing values:")
                for col, count in top_missing.items():
                    percent = count / len(data) * 100
                    info_text.append(f"  {col}: {count} ({percent:.1f}%)")
            ax8.text(0.02, 0.98, '\n'.join(info_text), transform=ax8.transAxes,
                     fontsize=8, verticalalignment='top', fontfamily='monospace')
# 9. Autocorrelation plot
ax9 = fig.add_subplot(gs[5, :2])
if target_col in data.columns:
try:
series = data[target_col].dropna()
if len(series) > 50:
plot_acf(series, lags=min(50, len(series)-1), ax=ax9, alpha=0.05)
ax9.set_title('Autocorrelation Function (ACF)', fontsize=12, fontweight='bold')
ax9.set_xlabel('Lag', fontsize=10)
ax9.set_ylabel('Autocorrelation', fontsize=10)
ax9.grid(True, alpha=0.3)
else:
ax9.text(0.5, 0.5, 'Insufficient data for ACF',
ha='center', va='center', transform=ax9.transAxes)
except Exception as e:
logger.warning(f"Error plotting ACF: {e}")
ax9.text(0.5, 0.5, 'Error calculating ACF',
ha='center', va='center', transform=ax9.transAxes)
# 10. Partial autocorrelation plot
ax10 = fig.add_subplot(gs[5, 2:])
if target_col in data.columns:
try:
series = data[target_col].dropna()
if len(series) > 50:
plot_pacf(series, lags=min(50, len(series)-1), ax=ax10, alpha=0.05)
ax10.set_title('Partial Autocorrelation Function (PACF)',
fontsize=12, fontweight='bold')
ax10.set_xlabel('Lag', fontsize=10)
ax10.set_ylabel('Partial Autocorrelation', fontsize=10)
ax10.grid(True, alpha=0.3)
else:
ax10.text(0.5, 0.5, 'Insufficient data for PACF',
ha='center', va='center', transform=ax10.transAxes)
except Exception as e:
logger.warning(f"Error plotting PACF: {e}")
ax10.text(0.5, 0.5, 'Error calculating PACF',
ha='center', va='center', transform=ax10.transAxes)
plt.suptitle('Data Analysis Summary Dashboard', fontsize=16, fontweight='bold', y=0.98)
plt.tight_layout()
# Save
filepath = self._save_figure(fig, filename, "summary")
self.plot_files['summary_dashboard'] = filepath
return filepath
except Exception as e:
logger.error(f"Error creating summary dashboard: {e}")
return None
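
    # Usage sketch (illustrative only, not executed here): `df` is a
    # hypothetical DataFrame with a DatetimeIndex and a column matching
    # config.target_column:
    #
    #     viz = VisualisationManager(config)
    #     path = viz.create_summary_dashboard(df, preprocessing_stages={'rows_dropped': 12})
    #     # -> <results_dir>/plots/summary/summary_dashboard.png, or None on error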
# ============================================
# SPECIFIC METHODS FOR SAVING YOUR PLOTS
# ============================================
    def save_data_split_plot(self, filename: str = "data_split.png") -> Optional[str]:
"""
Save data split plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "time_series")
self.plot_files['data_split'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving data_split plot: {e}")
return None
    def save_feature_selection_correlation_plot(self, filename: str = "feature_selection_correlation.png") -> Optional[str]:
"""
Save feature selection correlation plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "correlations")
self.plot_files['feature_selection_correlation'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving feature_selection_correlation plot: {e}")
return None
    def save_missing_values_analysis_plot(self, filename: str = "missing_values_analysis.png") -> Optional[str]:
"""
Save missing values analysis plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "preprocessing")
self.plot_files['missing_values_analysis'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving missing_values_analysis plot: {e}")
return None
    def save_outlier_handling_results_plot(self, filename: str = "outlier_handling_results.png") -> Optional[str]:
"""
Save outlier handling results plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "preprocessing")
self.plot_files['outlier_handling_results'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving outlier_handling_results plot: {e}")
return None
    def save_outliers_analysis_plot(self, filename: str = "outliers_analysis.png") -> Optional[str]:
"""
Save outliers analysis plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "preprocessing")
self.plot_files['outliers_analysis'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving outliers_analysis plot: {e}")
return None
    def save_scaling_results_plot(self, filename: str = "scaling_results.png") -> Optional[str]:
"""
Save scaling results plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "preprocessing")
self.plot_files['scaling_results'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving scaling_results plot: {e}")
return None
    def save_stationarity_analysis_plot(self, filename: str = "stationarity_analysis.png") -> Optional[str]:
"""
Save stationarity analysis plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "time_series")
self.plot_files['stationarity_analysis'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving stationarity_analysis plot: {e}")
return None
    def save_temporal_outliers_plot(self, filename: str = "temporal_outliers.png") -> Optional[str]:
"""
Save temporal outliers plot
Parameters:
-----------
filename : str
Filename for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, "time_series")
self.plot_files['temporal_outliers'] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving temporal_outliers plot: {e}")
return None
# ============================================
# UNIVERSAL METHOD FOR SAVING ANY PLOT
# ============================================
    def save_current_plot(self, filename: str, subdirectory: Optional[str] = None) -> Optional[str]:
"""
Universal method for saving current plot
Parameters:
-----------
filename : str
Filename for saving
subdirectory : str, optional
Subdirectory for saving
Returns:
--------
str : path to saved file
"""
try:
fig = plt.gcf() # Get current figure
filepath = self._save_figure(fig, filename, subdirectory)
            # Record the plot under its base name (extension stripped)
            plot_key = os.path.splitext(filename)[0]
self.plot_files[plot_key] = filepath
return filepath
except Exception as e:
logger.error(f"Error saving plot {filename}: {e}")
return None
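
    # Usage sketch for the save_* helpers above (illustrative only). They all
    # capture the current figure via plt.gcf(), so draw the plot immediately
    # before calling them and do not close it first:
    #
    #     plt.figure()
    #     plt.plot(train.index, train.values)   # `train` is a hypothetical Series
    #     viz.save_data_split_plot()            # -> plots/time_series/data_split.png
    #     viz.save_current_plot("custom_view", subdirectory="summary")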
# ============================================
# ADDITIONAL VISUALISATION METHODS
# ============================================
def create_feature_importance_plot(
self,
feature_importance: Dict,
top_n: int = 20,
filename: str = "feature_importance"
    ) -> Optional[str]:
"""
Create feature importance plot
Parameters:
-----------
feature_importance : Dict
Dictionary with feature importance
top_n : int
Number of top features to display
filename : str
Filename for saving
Returns:
--------
str : path to saved file or None if error
"""
if not feature_importance:
logger.warning("No feature importance data for visualisation")
return None
try:
# Convert to Series and sort
importance_series = pd.Series(feature_importance).sort_values(ascending=False)
top_features = importance_series.head(top_n)
# Create plot
fig, ax = plt.subplots(figsize=(12, 8))
y_pos = np.arange(len(top_features))
colors = plt.cm.plasma(np.linspace(0.2, 0.9, len(top_features)))
bars = ax.barh(y_pos, top_features.values, color=colors, edgecolor='black')
ax.set_yticks(y_pos)
ax.set_yticklabels(top_features.index, fontsize=10)
ax.invert_yaxis()
ax.set_xlabel('Feature Importance', fontsize=11, fontweight='bold')
ax.set_title(f'Top-{top_n} Most Important Features', fontsize=14, fontweight='bold')
ax.grid(True, alpha=0.3, axis='x')
# Add values on bars
for i, (bar, value) in enumerate(zip(bars, top_features.values)):
width = bar.get_width()
ax.text(width * 1.01, bar.get_y() + bar.get_height()/2,
f'{value:.4f}', va='center', fontsize=9, fontweight='bold')
            # Add additional information
            fig.text(0.02, 0.98, f'Total features: {len(importance_series)}',
                     fontsize=9, verticalalignment='top')
plt.tight_layout()
# Save
filepath = self._save_figure(fig, filename, "features")
self.plot_files['feature_importance'] = filepath
return filepath
except Exception as e:
logger.error(f"Error creating feature importance plot: {e}")
return None
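
    # Usage sketch (illustrative): importances may come from any fitted model,
    # e.g. a scikit-learn tree ensemble; `model` and `feature_names` below are
    # hypothetical:
    #
    #     importance = dict(zip(feature_names, model.feature_importances_))
    #     viz.create_feature_importance_plot(importance, top_n=15)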
def create_correlation_heatmap(
self,
data: pd.DataFrame,
top_n: int = 20,
filename: str = "correlation_heatmap"
    ) -> Tuple[Optional[str], Optional[str]]:
"""
Create correlation heatmap
Parameters:
-----------
data : pd.DataFrame
Data for analysis
top_n : int
Number of top features to display
filename : str
Filename for saving
Returns:
--------
        Tuple[Optional[str], Optional[str]]:
            (path to main heatmap, path to target correlation plot)
"""
target_col = self.config.target_column
try:
numeric_cols = data.select_dtypes(include=[np.number]).columns.tolist()
if len(numeric_cols) < 2:
logger.warning("Insufficient numeric features for correlation analysis")
return None, None
# Create two heatmaps
# 1. Main correlation heatmap between all features
main_filepath = self._create_main_correlation_heatmap(data, numeric_cols, top_n, filename)
# 2. Target correlation heatmap
target_filepath = None
if target_col in data.columns and target_col in numeric_cols:
target_filepath = self._create_target_correlation_heatmap(data, target_col, numeric_cols, filename)
return main_filepath, target_filepath
except Exception as e:
logger.error(f"Error creating correlation heatmap: {e}")
return None, None
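
    # Usage sketch (illustrative): both paths are returned; the second is None
    # when the target column is missing or non-numeric:
    #
    #     main_path, target_path = viz.create_correlation_heatmap(df, top_n=20)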
def _create_main_correlation_heatmap(
self,
data: pd.DataFrame,
numeric_cols: List[str],
top_n: int,
filename: str
    ) -> Optional[str]:
"""Create main correlation heatmap"""
# Limit number of features for better readability
if len(numeric_cols) > top_n:
# Select features with highest variance
variances = data[numeric_cols].var().sort_values(ascending=False)
selected_cols = variances.head(top_n).index.tolist()
else:
selected_cols = numeric_cols
# Calculate correlation
corr_matrix = data[selected_cols].corr()
fig, ax = plt.subplots(figsize=(14, 12))
# Mask for upper triangle
mask = np.triu(np.ones_like(corr_matrix, dtype=bool))
# Create heatmap
sns.heatmap(
corr_matrix,
annot=True,
fmt='.2f',
cmap='coolwarm',
center=0,
square=True,
mask=mask,
cbar_kws={'shrink': 0.8, 'label': 'Correlation Coefficient'},
linewidths=0.5,
linecolor='white',
ax=ax,
annot_kws={'size': 8}
)
ax.set_title(f'Correlation Matrix Between Features (Top-{top_n})',
fontsize=14, fontweight='bold', pad=20)
plt.tight_layout()
# Save
filepath = self._save_figure(fig, filename, "correlations")
self.plot_files['correlation_heatmap_main'] = filepath
return filepath
    def _create_target_correlation_heatmap(
        self,
        data: pd.DataFrame,
        target_col: str,
        numeric_cols: List[str],
        filename: str
    ) -> Optional[str]:
        """Create a bar chart of feature correlations with the target variable"""
# Calculate correlations with target variable
correlations = data[numeric_cols].corrwith(data[target_col]).sort_values(key=abs, ascending=False)
# Exclude target variable itself
correlations = correlations[correlations.index != target_col]
# Take top 15 features
top_features = correlations.head(15)
fig, ax = plt.subplots(figsize=(10, 8))
colors = ['red' if x < 0 else 'green' for x in top_features.values]
bars = ax.barh(range(len(top_features)), top_features.values, color=colors, edgecolor='black')
ax.set_yticks(range(len(top_features)))
ax.set_yticklabels(top_features.index, fontsize=10)
ax.invert_yaxis()
ax.set_xlabel('Correlation Coefficient', fontsize=11, fontweight='bold')
ax.set_title(f'Feature Correlations with Target Variable "{target_col}"',
fontsize=14, fontweight='bold', pad=20)
ax.grid(True, alpha=0.3, axis='x')
ax.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
# Add values on bars
for bar, value in zip(bars, top_features.values):
width = bar.get_width()
ax.text(width + (0.01 if width >= 0 else -0.04),
bar.get_y() + bar.get_height()/2,
f'{value:.3f}',
va='center',
ha='left' if width >= 0 else 'right',
fontsize=9,
fontweight='bold',
color='black')
plt.tight_layout()
# Save
target_filename = f"{filename}_with_target"
filepath = self._save_figure(fig, target_filename, "correlations")
self.plot_files['correlation_with_target'] = filepath
return filepath
    def create_distribution_comparison(
        self,
        original_data: pd.DataFrame,
        processed_data: pd.DataFrame,
        columns: Optional[List[str]] = None,
        max_columns: int = 12,
        filename: str = "distribution_comparison"
    ) -> Optional[str]:
"""
Compare distributions before and after processing
Parameters:
-----------
original_data : pd.DataFrame
Original data
processed_data : pd.DataFrame
Processed data
columns : List[str], optional
List of columns to compare
max_columns : int
Maximum number of columns to display
filename : str
Filename for saving
Returns:
--------
str : path to saved file or None if error
"""
try:
if columns is None:
# Select numeric columns common to both datasets
numeric_cols_original = original_data.select_dtypes(include=[np.number]).columns
numeric_cols_processed = processed_data.select_dtypes(include=[np.number]).columns
common_cols = list(set(numeric_cols_original) & set(numeric_cols_processed))
# Sort by variance in original data
variances = original_data[common_cols].var().sort_values(ascending=False)
columns = variances.head(max_columns).index.tolist()
n_cols = min(4, len(columns))
n_rows = (len(columns) + n_cols - 1) // n_cols
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 4, n_rows * 3.5))
fig.suptitle('Distribution Comparison Before and After Processing',
fontsize=16, fontweight='bold', y=0.98)
            # Normalise axes to a flat 1-D array regardless of grid shape
            axes = np.atleast_1d(axes).ravel()
for idx, col in enumerate(columns):
if idx >= len(axes):
break
ax = axes[idx]
if col in original_data.columns and col in processed_data.columns:
original_values = original_data[col].dropna()
processed_values = processed_data[col].dropna()
if len(original_values) > 0 and len(processed_values) > 0:
# Use common bins for comparison
all_values = pd.concat([original_values, processed_values])
bins = np.histogram_bin_edges(all_values, bins=30)
# Histograms
ax.hist(original_values, bins=bins, alpha=0.5,
label='Before Processing', density=True, color='blue')
ax.hist(processed_values, bins=bins, alpha=0.5,
label='After Processing', density=True, color='orange')
# Add KDE
try:
if len(original_values) > 10:
kde_original = gaussian_kde(original_values)
x_range = np.linspace(original_values.min(), original_values.max(), 100)
ax.plot(x_range, kde_original(x_range), 'b-', linewidth=1.5, alpha=0.8)
if len(processed_values) > 10:
kde_processed = gaussian_kde(processed_values)
x_range = np.linspace(processed_values.min(), processed_values.max(), 100)
ax.plot(x_range, kde_processed(x_range), 'orange', linewidth=1.5, alpha=0.8)
                        except Exception as e:
                            logger.debug(f"KDE overlay skipped: {e}")
# Add statistics
stats_text = []
                        if len(original_values) > 0:
                            stats_text.append(f"Before: μ={original_values.mean():.2f}, σ={original_values.std():.2f}")
                        if len(processed_values) > 0:
                            stats_text.append(f"After: μ={processed_values.mean():.2f}, σ={processed_values.std():.2f}")
if stats_text:
ax.text(0.02, 0.98, '\n'.join(stats_text),
transform=ax.transAxes, fontsize=8,
verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
ax.set_title(f'{col}', fontsize=11, fontweight='bold')
ax.set_xlabel('Value', fontsize=9)
ax.set_ylabel('Density', fontsize=9)
ax.legend(fontsize=8)
ax.grid(True, alpha=0.3)
else:
ax.text(0.5, 0.5, 'No data',
ha='center', va='center', transform=ax.transAxes)
else:
ax.text(0.5, 0.5, 'Column not found',
ha='center', va='center', transform=ax.transAxes)
# Hide unused subplots
for idx in range(len(columns), len(axes)):
axes[idx].set_visible(False)
plt.tight_layout()
# Save
filepath = self._save_figure(fig, filename, "distributions")
self.plot_files['distribution_comparison'] = filepath
return filepath
except Exception as e:
logger.error(f"Error creating distribution comparison: {e}")
return None
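
    # Usage sketch (illustrative): compares the numeric columns shared by both
    # frames; `raw_df` and `clean_df` are hypothetical DataFrames:
    #
    #     viz.create_distribution_comparison(raw_df, clean_df, max_columns=8)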
def create_time_series_decomposition_plot(
self,
decomposition_result: Dict,
filename: str = "time_series_decomposition"
    ) -> Optional[str]:
"""
Visualise time series decomposition
Parameters:
-----------
decomposition_result : Dict
Decomposition results
filename : str
Filename for saving
Returns:
--------
str : path to saved file or None if error
"""
target_col = self.config.target_column
try:
fig, axes = plt.subplots(4, 1, figsize=(14, 10))
fig.suptitle(f'Time Series Decomposition: {target_col}',
fontsize=16, fontweight='bold', y=0.98)
# Original series
if 'observed' in decomposition_result:
observed = decomposition_result['observed']
axes[0].plot(observed, color='blue', linewidth=1.5)
axes[0].set_ylabel('Observed', fontsize=11, fontweight='bold')
axes[0].grid(True, alpha=0.3)
axes[0].set_title('Original Time Series', fontsize=12)
# Trend
if 'trend' in decomposition_result and decomposition_result['trend'] is not None:
trend = decomposition_result['trend']
axes[1].plot(trend, color='red', linewidth=2)
axes[1].set_ylabel('Trend', fontsize=11, fontweight='bold')
axes[1].grid(True, alpha=0.3)
axes[1].set_title('Trend Component', fontsize=12)
# Seasonality
if 'seasonal' in decomposition_result and decomposition_result['seasonal'] is not None:
seasonal = decomposition_result['seasonal']
axes[2].plot(seasonal, color='green', linewidth=1.5)
axes[2].set_ylabel('Seasonal', fontsize=11, fontweight='bold')
axes[2].grid(True, alpha=0.3)
axes[2].set_title('Seasonal Component', fontsize=12)
# Residuals
if 'residual' in decomposition_result and decomposition_result['residual'] is not None:
residual = decomposition_result['residual']
axes[3].plot(residual, color='purple', linewidth=1, alpha=0.7)
axes[3].set_ylabel('Residuals', fontsize=11, fontweight='bold')
axes[3].set_xlabel('Date', fontsize=11, fontweight='bold')
axes[3].grid(True, alpha=0.3)
axes[3].set_title('Residual Component', fontsize=12)
# Add residual statistics
if len(residual) > 0:
stats_text = (f"Mean: {residual.mean():.4f}\n"
f"Std: {residual.std():.4f}\n"
f"Min: {residual.min():.4f}\n"
f"Max: {residual.max():.4f}")
axes[3].text(0.02, 0.98, stats_text, transform=axes[3].transAxes,
fontsize=8, verticalalignment='top',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))
plt.tight_layout()
# Save
filepath = self._save_figure(fig, filename, "time_series")
self.plot_files['time_series_decomposition'] = filepath
return filepath
except Exception as e:
logger.error(f"Error creating time series decomposition: {e}")
return None
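
    # Usage sketch (illustrative): the method expects the keys 'observed',
    # 'trend', 'seasonal' and 'residual'. A statsmodels decomposition can be
    # adapted like this:
    #
    #     from statsmodels.tsa.seasonal import seasonal_decompose
    #     res = seasonal_decompose(df[target_col].dropna(), period=365)
    #     viz.create_time_series_decomposition_plot({
    #         'observed': res.observed, 'trend': res.trend,
    #         'seasonal': res.seasonal, 'residual': res.resid,
    #     })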
def create_data_quality_report(
self,
validation_results: Dict,
filename: str = "data_quality_report"
    ) -> Optional[str]:
"""
Create visual data quality report
Parameters:
-----------
validation_results : Dict
Validation results
filename : str
Filename for saving
Returns:
--------
str : path to saved file or None if error
"""
try:
fig = plt.figure(figsize=(16, 12))
fig.suptitle('Data Quality Report', fontsize=18, fontweight='bold', y=0.98)
# Use GridSpec for more complex layout
gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
# 1. Quality radar chart (top left)
ax1 = fig.add_subplot(gs[0, 0], projection='polar')
categories = ['Size', 'Missing', 'Duplicates', 'Stability', 'Informativeness']
# Extract values from validation results
if 'quality_metrics' in validation_results:
values = [
validation_results['quality_metrics'].get('size_score', 0.5),
validation_results['quality_metrics'].get('missing_score', 0.5),
validation_results['quality_metrics'].get('duplicates_score', 0.5),
validation_results['quality_metrics'].get('stability_score', 0.5),
validation_results['quality_metrics'].get('informativeness_score', 0.5)
]
else:
values = [0.8, 0.7, 0.9, 0.6, 0.8]
N = len(categories)
angles = [n / float(N) * 2 * np.pi for n in range(N)]
angles += angles[:1]
values += values[:1]
ax1.plot(angles, values, 'o-', linewidth=2, color='blue')
ax1.fill(angles, values, alpha=0.25, color='blue')
ax1.set_xticks(angles[:-1])
ax1.set_xticklabels(categories, fontsize=10)
ax1.set_ylim(0, 1)
ax1.set_title('Data Quality Radar Chart', fontsize=12, fontweight='bold')
ax1.grid(True)
            # 2. Check status (top middle)
ax2 = fig.add_subplot(gs[0, 1])
basic_checks = validation_results.get('basic_checks', {})
checks_passed = sum(1 for check in basic_checks.values() if check.get('passed', False))
checks_total = len(basic_checks)
checks_failed = checks_total - checks_passed
if checks_total > 0:
colors = ['#4CAF50' if checks_passed > 0 else '#FF6B6B',
'#FF6B6B' if checks_failed > 0 else '#4CAF50']
bars = ax2.bar(['Passed', 'Failed'],
[checks_passed, checks_failed],
color=colors, edgecolor='black')
ax2.set_title(f'Basic Checks: {checks_passed}/{checks_total}',
fontsize=12, fontweight='bold')
ax2.set_ylabel('Number of Checks', fontsize=10)
ax2.grid(True, alpha=0.3, axis='y')
# Add values on bars
for bar, value in zip(bars, [checks_passed, checks_failed]):
height = bar.get_height()
ax2.text(bar.get_x() + bar.get_width()/2., height,
f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
else:
ax2.text(0.5, 0.5, 'No check data available',
ha='center', va='center', transform=ax2.transAxes)
ax2.set_title('Basic Checks', fontsize=12, fontweight='bold')
# 3. Overall score (top right)
ax3 = fig.add_subplot(gs[0, 2])
overall_score = validation_results.get('overall_score', 0)
status = validation_results.get('status', 'UNKNOWN')
# Score pie chart
sizes = [overall_score, 100 - overall_score]
if overall_score >= 80:
colors = ['#4CAF50', '#E0E0E0'] # Green
elif overall_score >= 60:
colors = ['#FFC107', '#E0E0E0'] # Yellow
else:
colors = ['#F44336', '#E0E0E0'] # Red
wedges, texts, autotexts = ax3.pie(sizes, colors=colors, startangle=90,
autopct='%1.1f%%', pctdistance=0.85)
# Central text
status_colors = {'PASS': '#4CAF50', 'WARNING': '#FFC107', 'FAIL': '#F44336'}
status_color = status_colors.get(status, '#757575')
ax3.text(0, 0, f'{overall_score}/100\n{status}',
ha='center', va='center', fontsize=14, fontweight='bold',
color=status_color)
ax3.set_title('Overall Quality Score', fontsize=12, fontweight='bold')
# 4. Issue distribution by type (left middle)
ax4 = fig.add_subplot(gs[1, 0])
issues = validation_results.get('issues', {})
issue_counts = {
'Critical': len(issues.get('critical', [])),
'Warnings': len(issues.get('warning', [])),
'Informational': len(issues.get('info', []))
}
if any(issue_counts.values()):
colors = ['#F44336', '#FF9800', '#2196F3']
bars = ax4.bar(issue_counts.keys(), issue_counts.values(),
color=colors, edgecolor='black')
ax4.set_title('Data Issues by Type', fontsize=12, fontweight='bold')
ax4.set_ylabel('Number of Issues', fontsize=10)
ax4.tick_params(axis='x', rotation=45)
ax4.grid(True, alpha=0.3, axis='y')
# Add values on bars
for bar, value in zip(bars, issue_counts.values()):
height = bar.get_height()
ax4.text(bar.get_x() + bar.get_width()/2., height,
f'{value}', ha='center', va='bottom', fontsize=10, fontweight='bold')
else:
ax4.text(0.5, 0.5, 'No issues detected',
ha='center', va='center', transform=ax4.transAxes, fontsize=12)
ax4.set_title('Data Issues', fontsize=12, fontweight='bold')
# 5. Detailed information (remaining cells)
ax5 = fig.add_subplot(gs[1:, 1:])
ax5.axis('off')
# Form text report
report_text = []
report_text.append("DETAILED REPORT:")
report_text.append("=" * 40)
# Basic information
            report_text.append("\nBASIC INFORMATION:")
            report_text.append(f"• Overall score: {overall_score}/100")
            report_text.append(f"• Status: {status}")
            report_text.append(f"• Checks passed: {checks_passed}/{checks_total}")
            # Check details
            if basic_checks:
                report_text.append("\nCHECK DETAILS:")
                for check_name, check_result in basic_checks.items():
                    status_icon = "✓" if check_result.get('passed', False) else "✗"
                    report_text.append(f"• {status_icon} {check_name}: {check_result.get('message', '')}")
            # Issues
            if any(issue_counts.values()):
                report_text.append("\nDETECTED ISSUES:")
                if issue_counts['Critical'] > 0:
                    report_text.append("\nCRITICAL:")
                    for issue in issues.get('critical', []):
                        report_text.append(f"  • {issue}")
                if issue_counts['Warnings'] > 0:
                    report_text.append("\nWARNINGS:")
                    for issue in issues.get('warning', []):
                        report_text.append(f"  • {issue}")
                if issue_counts['Informational'] > 0:
                    report_text.append("\nINFORMATIONAL:")
                    for issue in issues.get('info', []):
                        report_text.append(f"  • {issue}")
# Recommendations
recommendations = validation_results.get('recommendations', [])
if recommendations:
report_text.append("\nRECOMMENDATIONS:")
for i, rec in enumerate(recommendations, 1):
report_text.append(f"{i}. {rec}")
ax5.text(0.02, 0.98, '\n'.join(report_text), transform=ax5.transAxes,
fontsize=9, verticalalignment='top', fontfamily='monospace',
bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.1))
plt.tight_layout()
# Save
filepath = self._save_figure(fig, filename, "reports")
self.plot_files['data_quality_report'] = filepath
return filepath
except Exception as e:
logger.error(f"Error creating data quality report: {e}")
return None
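
    # Usage sketch (illustrative): the `validation_results` layout below is
    # inferred from the lookups above; every key is optional:
    #
    #     viz.create_data_quality_report({
    #         'overall_score': 78,
    #         'status': 'WARNING',
    #         'quality_metrics': {'size_score': 0.9, 'missing_score': 0.7},
    #         'basic_checks': {'no_duplicates': {'passed': True, 'message': 'OK'}},
    #         'issues': {'critical': [], 'warning': ['5% missing in col X'], 'info': []},
    #         'recommendations': ['Impute missing values in col X'],
    #     })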
# ============================================
# METHODS FOR BATCH SAVING
# ============================================
def save_all_preprocessing_plots(self) -> Dict[str, str]:
"""
Save all preprocessing plots from current session
Returns:
--------
Dict[str, str] : dictionary with paths to saved plots
"""
logger.info("Saving all preprocessing plots...")
plots_saved = {}
# Get all open figures
figure_numbers = plt.get_fignums()
if not figure_numbers:
logger.warning("No open plots to save")
return plots_saved
# Save each plot
for fig_num in figure_numbers:
fig = plt.figure(fig_num)
filename = f"preprocessing_plot_{fig_num}.png"
filepath = self._save_figure(fig, filename, "preprocessing")
if filepath:
plots_saved[f"plot_{fig_num}"] = filepath
logger.info(f"Saved {len(plots_saved)} preprocessing plots")
return plots_saved
def create_all_visualizations(
self,
data: pd.DataFrame,
        processed_data: Optional[pd.DataFrame] = None,
        feature_importance: Optional[Dict] = None,
        decomposition_result: Optional[Dict] = None,
        validation_results: Optional[Dict] = None,
        preprocessing_stages: Optional[Dict] = None
) -> Dict[str, str]:
"""
Create all visualisations in one call
Parameters:
-----------
data : pd.DataFrame
Original data
processed_data : pd.DataFrame, optional
Processed data
feature_importance : Dict, optional
Feature importance
decomposition_result : Dict, optional
Decomposition results
validation_results : Dict, optional
Validation results
preprocessing_stages : Dict, optional
Preprocessing stages
Returns:
--------
Dict[str, str] : dictionary with paths to created plots
"""
logger.info("\n" + "="*80)
        logger.info("CREATING ALL VISUALISATIONS")
logger.info("="*80)
result_files = {}
# 1. Summary dashboard
if data is not None:
logger.info("Creating summary dashboard...")
summary_path = self.create_summary_dashboard(data, preprocessing_stages)
if summary_path:
result_files['summary'] = summary_path
# 2. Correlation heatmaps
if data is not None:
logger.info("Creating correlation heatmaps...")
main_corr, target_corr = self.create_correlation_heatmap(data)
if main_corr:
result_files['correlation_main'] = main_corr
if target_corr:
result_files['correlation_target'] = target_corr
# 3. Distribution comparison
if data is not None and processed_data is not None:
logger.info("Creating distribution comparison...")
dist_path = self.create_distribution_comparison(data, processed_data)
if dist_path:
result_files['distribution'] = dist_path
# 4. Feature importance
if feature_importance:
logger.info("Creating feature importance plot...")
feat_path = self.create_feature_importance_plot(feature_importance)
if feat_path:
result_files['feature_importance'] = feat_path
# 5. Time series decomposition
if decomposition_result:
logger.info("Creating time series decomposition...")
decomp_path = self.create_time_series_decomposition_plot(decomposition_result)
if decomp_path:
result_files['decomposition'] = decomp_path
# 6. Data quality report
if validation_results:
logger.info("Creating data quality report...")
quality_path = self.create_data_quality_report(validation_results)
if quality_path:
result_files['quality_report'] = quality_path
# Save information about all plots
self.save_plots_info()
logger.info("\n" + "="*80)
        logger.info("VISUALISATION CREATION COMPLETE")
logger.info("="*80)
for plot_name, plot_path in result_files.items():
if plot_path:
                logger.info(f"✓ {plot_name}: {plot_path}")
return result_files
def get_all_plots(self) -> Dict:
"""Get information about all created plots"""
return self.plot_files
def save_plots_info(self, filename: str = "plots_info.json") -> None:
"""Save plot information to JSON file"""
try:
plots_info = {
'total_plots': len(self.plot_files),
'plots': self.plot_files,
'directories': {
'correlations': self.correlations_dir,
'distributions': self.distributions_dir,
'features': self.features_dir,
'time_series': self.time_series_dir,
'preprocessing': self.preprocessing_dir,
'summary': self.summary_dir,
'reports': self.reports_dir
},
'generation_time': datetime.now().strftime('%Y-%m-%d %H:%M:%S'),
'config': {
'target_column': self.config.target_column,
'results_dir': self.config.results_dir
}
}
filepath = os.path.join(self.reports_dir, filename)
with open(filepath, 'w', encoding='utf-8') as f:
json.dump(plots_info, f, indent=4, ensure_ascii=False, default=str)
            logger.info(f"✓ Plot information saved: {filepath}")
except Exception as e:
            logger.error(f"✗ Error saving plot information: {e}")
    def move_existing_plots(self, source_dir: Optional[str] = None) -> Dict[str, str]:
"""
Move existing plots from specified directory to structured folders
Parameters:
-----------
source_dir : str, optional
Directory with existing plots
Returns:
--------
Dict[str, str] : dictionary with information about moved files
"""
if source_dir is None:
source_dir = self.plots_dir
if not os.path.exists(source_dir):
logger.warning(f"Source directory doesn't exist: {source_dir}")
return {}
# File to folder mapping
file_to_folder_map = {
# Time series
'data_split.png': 'time_series',
'stationarity_raskhodvoda.png': 'time_series',
'stationarity_analysis.png': 'time_series',
'temporal_outliers.png': 'time_series',
# Correlations
'feature_selection_correlation.png': 'correlations',
# Preprocessing
'missing_values_analysis.png': 'preprocessing',
'outlier_handling_results.png': 'preprocessing',
'outliers_analysis.png': 'preprocessing',
'scaling_results.png': 'preprocessing',
# Default
'default': 'summary'
}
moved_files = {}
for filename in os.listdir(source_dir):
if filename.endswith('.png'):
source_path = os.path.join(source_dir, filename)
# Determine destination folder
target_folder = file_to_folder_map.get(filename, file_to_folder_map['default'])
target_dir = os.path.join(self.plots_dir, target_folder)
# Create destination folder if doesn't exist
os.makedirs(target_dir, exist_ok=True)
# Target path
target_path = os.path.join(target_dir, filename)
                try:
                    # Move the file (shutil.move handles cross-filesystem moves, unlike os.rename)
                    shutil.move(source_path, target_path)
moved_files[filename] = target_path
logger.info(f"Moved: {filename} -> {target_folder}/")
except Exception as e:
logger.error(f"Error moving {filename}: {e}")
logger.info(f"Moved {len(moved_files)} files")
return moved_files
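

# ============================================
# MINIMAL SMOKE-TEST SKETCH (ILLUSTRATIVE)
# ============================================
# A minimal sketch, assuming Config() is default-constructible and that its
# `target_column` names a column present in the synthetic frame below; adapt
# both to the real project configuration before running.
if __name__ == "__main__":
    rng = np.random.default_rng(42)
    idx = pd.date_range("2020-01-01", periods=730, freq="D")
    demo = pd.DataFrame(
        {
            "target": np.sin(np.arange(730) / 30.0) + rng.normal(0, 0.2, 730),
            "feature_1": rng.normal(0, 1, 730),
            "feature_2": rng.uniform(0, 10, 730),
        },
        index=idx,
    )
    cfg = Config()  # assumption: default-constructible; see note above
    viz = VisualisationManager(cfg)
    created = viz.create_all_visualizations(demo)
    logger.info(f"Demo produced {len(created)} plot file(s): {sorted(created)}")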