3v324v23's picture
initial
c689089
# plot_metrics.py
import os
import json
import pandas as pd
import matplotlib.pyplot as plt
from transformer_model.scripts.config_transformer import RESULTS_DIR
# === Plot 1: Training Metrics ===
# Read the metric histories stored as JSON under RESULTS_DIR.
training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
with open(training_metrics_path, "r") as f:
    metrics = json.load(f)

# The three histories plotted below; each is indexed from 1 on the x-axis.
train_losses = metrics["train_losses"]
test_mses = metrics["test_mses"]
test_maes = metrics["test_maes"]

# Draw all histories on one shared axis. Each series gets its own 1..len
# range, so histories of different lengths still plot without error.
fig, ax = plt.subplots(figsize=(10, 6))
for series_label, series_values, series_color in (
    ("Train Loss", train_losses, "blue"),
    ("Test MSE", test_mses, "red"),
    ("Test MAE", test_maes, "green"),
):
    ax.plot(
        range(1, len(series_values) + 1),
        series_values,
        label=series_label,
        color=series_color,
    )

ax.set_xlabel("Epoch")
ax.set_ylabel("Loss / Metric")
ax.set_title("Training Loss vs Test Metrics")
ax.legend()
ax.grid(True)

# Write the figure to disk before showing it.
plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
fig.savefig(plot_path)
print(f"[Saved] Training metrics plot: {plot_path}")
plt.show()
# === Plot 2: Predictions vs Ground Truth (Full Range) ===
# Load the per-timestamp comparison table; "Timestamp" is parsed to datetimes
# so it can be used directly as the time axis.
comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])

true_column = "True Consumption (MW)"
predicted_column = "Predicted Consumption (MW)"

# Overlay the predicted series (dashed red) on the true series (solid dark
# blue) across every row of the table.
fig, ax = plt.subplots(figsize=(15, 6))
ax.plot(
    df_comparison["Timestamp"],
    df_comparison[true_column],
    label="True",
    color="darkblue",
)
ax.plot(
    df_comparison["Timestamp"],
    df_comparison[predicted_column],
    label="Predicted",
    color="red",
    linestyle="--",
)
ax.set_title("Energy Consumption: Predictions vs Ground Truth")
ax.set_xlabel("Time")
ax.set_ylabel("Consumption (MW)")
ax.legend()
ax.grid(True)
fig.tight_layout()

# Write the figure to disk before showing it.
plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
fig.savefig(plot_path)
print(f"[Saved] Full range comparison plot: {plot_path}")
plt.show()
# === Plot 3: Predictions vs Ground Truth (First Month) ===
# Restrict the comparison table to the first calendar month of data, counted
# from the earliest timestamp. A calendar-month offset is used (rather than a
# fixed number of days) so the window matches the "First Month" title and the
# "1month" output file name.
first_month_start = df_comparison["Timestamp"].min()
first_month_end = first_month_start + pd.DateOffset(months=1)

# Series.between is inclusive on both ends, so rows stamped exactly at the
# start or end of the window are kept.
in_first_month = df_comparison["Timestamp"].between(first_month_start, first_month_end)
df_first_month = df_comparison[in_first_month]

# Overlay the predicted series (dashed red) on the true series (solid dark
# blue) for the selected window only.
plt.figure(figsize=(15, 6))
plt.plot(
    df_first_month["Timestamp"],
    df_first_month["True Consumption (MW)"],
    label="True",
    color="darkblue",
)
plt.plot(
    df_first_month["Timestamp"],
    df_first_month["Predicted Consumption (MW)"],
    label="Predicted",
    color="red",
    linestyle="--",
)
plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
plt.xlabel("Time")
plt.ylabel("Consumption (MW)")
plt.legend()
plt.grid(True)
plt.tight_layout()

# Write the figure to disk before showing it.
plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
plt.savefig(plot_path)
print(f"[Saved] 1-Month comparison plot: {plot_path}")
plt.show()