# train.py

import os
import json
import time
import logging
import numpy as np
import torch
from tqdm import tqdm
from sklearn.metrics import mean_squared_error, mean_absolute_error

from transformer_model.scripts.config_transformer import (
    MAX_EPOCHS,
    LEARNING_RATE,
    MAX_LR,
    GRAD_CLIP,
    CHECKPOINT_DIR,
    RESULTS_DIR
)

from transformer_model.scripts.training.load_basis_model import load_moment_model
from transformer_model.scripts.utils.create_dataloaders import create_dataloaders
from transformer_model.scripts.utils.check_device import check_device
from momentfm.utils.utils import control_randomness


# === Setup logging ===
logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s")


def train():
    # Start timing
    start_time = time.time()

    # Setup device (CUDA / DirectML / CPU) and AMP scaler
    device, backend, scaler = check_device()
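    # check_device is assumed to return (torch.device, backend name, GradScaler
    # or None); a None scaler disables AMP in the loops below (e.g. CPU/DirectML).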

    # Set random seeds for reproducibility before the model is created,
    # so any randomly initialized weights are seeded as well
    control_randomness(seed=13)

    # Load base model
    model = load_moment_model().to(device)

    # Setup loss function and optimizer
    criterion = torch.nn.MSELoss().to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)

    # Load data
    train_loader, test_loader = create_dataloaders()
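    # Each batch is assumed to be a (timeseries, forecast, input_mask) triple,
    # matching the unpacking in the training and evaluation loops below.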

    # Setup learning rate scheduler (OneCycle policy). Because total_steps is
    # counted in batches, the scheduler must be stepped once per batch, not per epoch.
    total_steps = len(train_loader) * MAX_EPOCHS
    scheduler = torch.optim.lr_scheduler.OneCycleLR(
        optimizer,
        max_lr=MAX_LR,
        total_steps=total_steps,
        pct_start=0.3
    )
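    # With pct_start=0.3, the LR ramps up to MAX_LR over the first 30% of all
    # batch steps, then anneals (cosine by default) over the remaining 70%.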

    # Ensure output folders exist
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    os.makedirs(RESULTS_DIR, exist_ok=True)

    # Store metrics
    train_losses, test_mses, test_maes = [], [], []

    best_mae = float("inf")
    best_epoch = None
    no_improve_epochs = 0
    patience = 5  # stop after this many epochs without MAE improvement

    for epoch in range(MAX_EPOCHS):
        model.train()
        epoch_losses = []

        for timeseries, forecast, input_mask in tqdm(train_loader, desc=f"Epoch {epoch}"):
            timeseries = timeseries.float().to(device)
            input_mask = input_mask.to(device)
            forecast = forecast.float().to(device)

            # Zero gradients
            optimizer.zero_grad(set_to_none=True)

            # Forward pass (AMP autocast only when a CUDA GradScaler was
            # returned, so hard-coding device_type="cuda" is safe here)
            if scaler:
                with torch.amp.autocast(device_type="cuda"):
                    output = model(x_enc=timeseries, input_mask=input_mask)
                    loss = criterion(output.forecast, forecast)
            else:
                output = model(x_enc=timeseries, input_mask=input_mask)
                loss = criterion(output.forecast, forecast)

            # Backward pass + optimization
            if scaler:
                scaler.scale(loss).backward()
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                scaler.step(optimizer)
                scaler.update()
            else:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), GRAD_CLIP)
                optimizer.step()

            # OneCycleLR is stepped per batch to match total_steps above
            scheduler.step()

            epoch_losses.append(loss.item())

        average_train_loss = np.mean(epoch_losses)
        train_losses.append(average_train_loss)
        logging.info(f"Epoch {epoch}: Train Loss = {average_train_loss:.4f}")

        # === Evaluation ===
        model.eval()
        trues, preds = [], []

        with torch.no_grad():
            for timeseries, forecast, input_mask in test_loader:
                timeseries = timeseries.float().to(device)
                input_mask = input_mask.to(device)
                forecast = forecast.float().to(device)

                if scaler:
                    with torch.amp.autocast(device_type="cuda"):
                        output = model(x_enc=timeseries, input_mask=input_mask)
                else:
                    output = model(x_enc=timeseries, input_mask=input_mask)

                trues.append(forecast.detach().cpu().numpy())
                preds.append(output.forecast.detach().cpu().numpy())

        trues = np.concatenate(trues, axis=0)
        preds = np.concatenate(preds, axis=0)

        # Reshape to 2-D for sklearn metrics; MSE/MAE then average over
        # every sample and forecast element
        trues_2d = trues.reshape(trues.shape[0], -1)
        preds_2d = preds.reshape(preds.shape[0], -1)

        mse = mean_squared_error(trues_2d, preds_2d)
        mae = mean_absolute_error(trues_2d, preds_2d)

        test_mses.append(mse)
        test_maes.append(mae)
        logging.info(f"Epoch {epoch}: Test MSE = {mse:.4f}, MAE = {mae:.4f}")

        # === Early Stopping Check ===
        if mae < best_mae:
            best_mae = mae
            best_epoch = epoch
            no_improve_epochs = 0

            # Save best model
            best_model_path = os.path.join(CHECKPOINT_DIR, "best_model.pth")
            torch.save(model.state_dict(), best_model_path)
            logging.info(f"New best model saved to: {best_model_path} (MAE: {best_mae:.4f})")
        else:
            no_improve_epochs += 1
            logging.info(f"No improvement in MAE for {no_improve_epochs} epoch(s).")

            if no_improve_epochs >= patience:
                logging.info("Early stopping triggered.")
                break

        # Save checkpoint
        checkpoint_path = os.path.join(CHECKPOINT_DIR, f"model_epoch_{epoch}.pth")
        torch.save(model.state_dict(), checkpoint_path)

    logging.info(f"Best model was at epoch {best_epoch} with MAE: {best_mae:.4f}")

    # Save final model
    final_model_path = os.path.join(CHECKPOINT_DIR, "model_final.pth")
    torch.save(model.state_dict(), final_model_path)
    logging.info(f"Final model saved to: {final_model_path}")
    logging.info(f"Final Test MSE: {test_mses[-1]:.4f}, MAE: {test_maes[-1]:.4f}")

    # Save training metrics
    metrics = {
        "train_losses": [float(x) for x in train_losses],
        "test_mses": [float(x) for x in test_mses],
        "test_maes": [float(x) for x in test_maes]
    }

    metrics_path = os.path.join(RESULTS_DIR, "training_metrics.json")
    with open(metrics_path, "w") as f:
        json.dump(metrics, f)
    logging.info(f"Training metrics saved to: {metrics_path}")

    # Done
    elapsed = time.time() - start_time
    logging.info(f"Training complete in {elapsed / 60:.2f} minutes.")


# === Entry Point ===
if __name__ == "__main__":
    try:
        train()
    except Exception:
        logging.exception("Training failed")
        raise