# plot_metrics.py
"""Plot training curves and prediction-vs-ground-truth comparisons.

Reads ``training_metrics.json`` and ``test_results.csv`` from ``RESULTS_DIR``
and writes three PNG figures back into the same directory, displaying each
one after it has been saved.
"""
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
from transformer_model.scripts.config_transformer import RESULTS_DIR

# Column names expected in test_results.csv.
TIMESTAMP_COL = "Timestamp"
TRUE_COL = "True Consumption (MW)"
PRED_COL = "Predicted Consumption (MW)"


def _save_and_show(filename, description):
    """Save the current figure to RESULTS_DIR/filename, report it, show it, then close it."""
    plot_path = os.path.join(RESULTS_DIR, filename)
    plt.savefig(plot_path)
    print(f"[Saved] {description}: {plot_path}")
    plt.show()
    # With a non-interactive backend show() is a no-op, so close explicitly to
    # avoid accumulating open figures for the rest of the process.
    plt.close()


def _plot_comparison(df, title, filename, description):
    """Plot true vs. predicted consumption over time for *df* and save the figure."""
    plt.figure(figsize=(15, 6))
    plt.plot(df[TIMESTAMP_COL], df[TRUE_COL], label="True", color="darkblue")
    plt.plot(df[TIMESTAMP_COL], df[PRED_COL], label="Predicted", color="red", linestyle="--")
    plt.title(title)
    plt.xlabel("Time")
    plt.ylabel("Consumption (MW)")
    plt.legend()
    plt.grid(True)
    plt.tight_layout()
    _save_and_show(filename, description)


# === Plot 1: Training Metrics ===
# Load training metrics
training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
with open(training_metrics_path, "r", encoding="utf-8") as f:
    metrics = json.load(f)

train_losses = metrics["train_losses"]
test_mses = metrics["test_mses"]
test_maes = metrics["test_maes"]

plt.figure(figsize=(10, 6))
# Epochs are numbered from 1; each series gets its own x-range, so the lists
# are not required to have equal length.
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", color="blue")
plt.plot(range(1, len(test_mses) + 1), test_mses, label="Test MSE", color="red")
plt.plot(range(1, len(test_maes) + 1), test_maes, label="Test MAE", color="green")
plt.xlabel("Epoch")
plt.ylabel("Loss / Metric")
plt.title("Training Loss vs Test Metrics")
plt.legend()
plt.grid(True)
plt.tight_layout()
_save_and_show("training_plot.png", "Training metrics plot")

# === Plot 2: Predictions vs Ground Truth (Full Range) ===
# Load comparison results
comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
df_comparison = pd.read_csv(comparison_path, parse_dates=[TIMESTAMP_COL])

_plot_comparison(
    df_comparison,
    "Energy Consumption: Predictions vs Ground Truth",
    "comparison_plot_full.png",
    "Full range comparison plot",
)

# === Plot 3: Predictions vs Ground Truth (First Month) ===
first_month_start = df_comparison[TIMESTAMP_COL].min()
# Use a calendar month so the window matches the "First Month" title and the
# "1month" filename; the end bound is exclusive so the next month's first
# sample is not included.
first_month_end = first_month_start + pd.DateOffset(months=1)
df_first_month = df_comparison[
    (df_comparison[TIMESTAMP_COL] >= first_month_start)
    & (df_comparison[TIMESTAMP_COL] < first_month_end)
]

_plot_comparison(
    df_first_month,
    "Energy Consumption (First Month): Predictions vs Ground Truth",
    "comparison_plot_1month.png",
    "1-Month comparison plot",
)