File size: 2,743 Bytes
c689089
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# plot_metrics.py

import os
import json
import pandas as pd
import matplotlib.pyplot as plt
from transformer_model.scripts.config_transformer import RESULTS_DIR

# === Plot 1: Training Metrics ===
#
# Reads the metric series stored in training_metrics.json under RESULTS_DIR,
# draws them on one figure, saves the figure as training_plot.png and then
# displays it.

# Load training metrics. The encoding is given explicitly so that decoding the
# JSON file does not depend on the platform's default locale encoding.
training_metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
with open(training_metrics_path, "r", encoding="utf-8") as f:
    metrics = json.load(f)

# A missing key raises KeyError here.
# Presumably each list holds one value per epoch -- confirm against the
# training script that writes this file.
train_losses = metrics["train_losses"]
test_mses = metrics["test_mses"]
test_maes = metrics["test_maes"]

plt.figure(figsize=(10, 6))
# The x-axis counts from 1; each curve is sized from its own length, so the
# three series do not have to be equally long.
plt.plot(range(1, len(train_losses) + 1), train_losses, label="Train Loss", color="blue")
plt.plot(range(1, len(test_mses) + 1), test_mses, label="Test MSE", color="red")
plt.plot(range(1, len(test_maes) + 1), test_maes, label="Test MAE", color="green")
plt.xlabel("Epoch")
plt.ylabel("Loss / Metric")
plt.title("Training Loss vs Test Metrics")
plt.legend()
plt.grid(True)
# Same layout pass as the two comparison figures in this script, so that axis
# labels are not clipped in the saved image.
plt.tight_layout()

plot_path = os.path.join(RESULTS_DIR, "training_plot.png")
# Save before plt.show(): once an interactive window has been closed the
# figure may no longer be available for saving.
plt.savefig(plot_path)
print(f"[Saved] Training metrics plot: {plot_path}")
plt.show()


# === Plot 2: Predictions vs Ground Truth (Full Range) ===
#
# Draws the true and the predicted consumption columns of test_results.csv
# against the "Timestamp" column, over every row in the file.

# Read the comparison table; "Timestamp" is parsed as dates while loading.
comparison_path = os.path.join(RESULTS_DIR, "test_results.csv")
df_comparison = pd.read_csv(comparison_path, parse_dates=["Timestamp"])

fig, ax = plt.subplots(figsize=(15, 6))
timestamps = df_comparison["Timestamp"]
# One entry per curve: the CSV column to draw and the line styling for it.
for column, style in (
    ("True Consumption (MW)", {"label": "True", "color": "darkblue"}),
    ("Predicted Consumption (MW)", {"label": "Predicted", "color": "red", "linestyle": "--"}),
):
    ax.plot(timestamps, df_comparison[column], **style)
ax.set_title("Energy Consumption: Predictions vs Ground Truth")
ax.set_xlabel("Time")
ax.set_ylabel("Consumption (MW)")
ax.legend()
ax.grid(True)
fig.tight_layout()

# Write the image to disk first, then display the figure.
plot_path = os.path.join(RESULTS_DIR, "comparison_plot_full.png")
fig.savefig(plot_path)
print(f"[Saved] Full range comparison plot: {plot_path}")
plt.show()


# === Plot 3: Predictions vs Ground Truth (First Month) ===
#
# Same comparison as the full-range plot, restricted to the first month of
# data in df_comparison, saved as comparison_plot_1month.png and displayed.

# The window starts at the earliest timestamp in the data and spans one
# calendar month, so the plotted range agrees with the "First Month" title
# and the "1month" file name.
first_month_start = df_comparison["Timestamp"].min()
first_month_end = first_month_start + pd.DateOffset(months=1)
# Both bounds are inclusive; rows whose timestamp is missing fail both
# comparisons and are therefore left out.
in_first_month = (df_comparison["Timestamp"] >= first_month_start) & (
    df_comparison["Timestamp"] <= first_month_end
)
df_first_month = df_comparison[in_first_month]

plt.figure(figsize=(15, 6))
plt.plot(df_first_month["Timestamp"], df_first_month["True Consumption (MW)"], label="True", color="darkblue")
plt.plot(df_first_month["Timestamp"], df_first_month["Predicted Consumption (MW)"], label="Predicted", color="red", linestyle="--")
plt.title("Energy Consumption (First Month): Predictions vs Ground Truth")
plt.xlabel("Time")
plt.ylabel("Consumption (MW)")
plt.legend()
plt.grid(True)
plt.tight_layout()

plot_path = os.path.join(RESULTS_DIR, "comparison_plot_1month.png")
# Save before plt.show(): once an interactive window has been closed the
# figure may no longer be available for saving.
plt.savefig(plot_path)
print(f"[Saved] 1-Month comparison plot: {plot_path}")
plt.show()