|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import tqdm |
|
|
from .setup import Setup, HookMonitor |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def train_step(
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        controller: Setup,
        scheduler: torch.optim.lr_scheduler.LRScheduler | None = None,
) -> float:
    """
    Performs one full training epoch: forward pass, loss calculation, backward
    pass and optimizer step for every batch, with optional AMP gradient scaling
    and an optional per-epoch learning-rate scheduler step.

    Parameters:
        model (torch.nn.Module): The model to be trained.
        data (torch.utils.data.DataLoader): DataLoader providing the training data.
            Each element must contain at least ``(x, y)``; an optional third entry
            is an input mask, an optional fourth entry a target mask, and any
            further entries are forwarded to the model as extra positional args.
        loss (torch.nn.Module): Loss function to be used.
        optimizer (torch.optim.Optimizer): Optimizer used for gradient updates.
        controller (Setup): The setup object containing configuration and state.
        scheduler (torch.optim.lr_scheduler.LRScheduler, optional): Learning rate
            scheduler, stepped once after the epoch completes.

    Returns:
        float: The mean loss value over all batches of this epoch.

    Raises:
        ValueError: If a DataLoader element has fewer than two entries.
    """
    model.to(controller.device)
    model.train()

    losses = list()

    with HookMonitor(model, controller.watcher['activations'], controller.logger) as hooks:
        with tqdm.tqdm(data, desc=f'\rTraining epoch {controller.epoch}', leave=True) as pbar:

            for i, element in enumerate(pbar):

                # Unpack (x, y[, x_mask[, y_mask[, *extra]]]); 'args' is reset
                # every iteration so extras never leak between batches.
                args = tuple()
                if len(element) == 2:
                    x, y = element
                    x_m, y_m = None, None
                elif len(element) == 3:
                    x, y, x_m = element
                    y_m = None
                elif len(element) == 4:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")

                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                optimizer.zero_grad()
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)

                if controller.autoscaler is not None:
                    # Mixed-precision path; autocast is only enabled on CUDA,
                    # elsewhere the context is a no-op.
                    with torch.amp.autocast(enabled=(controller.device.type == 'cuda'), device_type=controller.device.type):
                        y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                        loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)

                    controller.autoscaler.scale(loss_metric).backward()
                    controller.autoscaler.step(optimizer)
                    controller.autoscaler.update()
                else:
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)

                    loss_metric.backward()
                    optimizer.step()

                losses.append(loss_metric.item())

                # Keep the configured batch around for later inspection/replay.
                if controller.replay_id[0] == i:
                    controller.register_replay(predicted=y_hat, target=y, mask=y_m)

    losses = np.array(losses)
    mean_loss = float(np.mean(losses))

    # Record per-parameter state with the controller for monitoring.
    for name, parameter in model.named_parameters():
        controller.register(name, parameter)

    controller.register('loss', mean_loss)

    # Forward the activation statistics the hooks collected during the epoch
    # (stats stay readable after the HookMonitor context has exited).
    for layer_name, layer_stats in hooks.get_stats().items():
        for func_name, item in layer_stats.items():
            controller.register(f'{func_name}/{layer_name}', torch.Tensor([item])[0])

    if scheduler is not None:
        # Record the LR that was actually used this epoch, then advance.
        controller.register('lr', scheduler.get_last_lr()[0])
        scheduler.step()

    controller.logger.info(f"Epoch [{controller.epoch}]: loss = {mean_loss:.8f}")

    controller.check(model, optimizer, scheduler)

    return mean_loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def validation_step(
        model: torch.nn.Module,
        data: torch.utils.data.DataLoader,
        loss: torch.nn.Module,
        controller: Setup,
        additional_metrics: dict = (),
) -> dict:
    """
    Performs one full validation epoch (no gradient tracking) and aggregates
    the loss plus any additional metrics over all batches.

    Parameters:
        model (torch.nn.Module): The model to be validated.
        data (torch.utils.data.DataLoader): DataLoader providing the validation data.
            Each element must contain at least ``(x, y)``; an optional third entry
            is an input mask, an optional fourth entry a target mask, and any
            further entries are forwarded to the model as extra positional args.
        loss (torch.nn.Module): Loss function to be used.
        controller (Setup): The setup object containing configuration and state.
        additional_metrics (dict): Mapping of metric name to a callable invoked as
            ``metric(y_hat, y, y_m)``; each metric is averaged over the epoch.

    Returns:
        dict: Mean value per additional metric, plus the mean loss under 'loss'.

    Raises:
        ValueError: If a DataLoader element has fewer than two entries.
    """
    model.to(controller.device)
    model.eval()

    losses = list()
    metrics: dict[str, list | float] = {name: list() for name in additional_metrics}

    with torch.no_grad():
        with tqdm.tqdm(data, desc=f'\rValidation epoch {controller.epoch}', leave=True) as pbar:
            for element in pbar:

                # Unpack (x, y[, x_mask[, y_mask[, *extra]]]). 'args' is reset
                # up front every iteration: previously it was never set in the
                # len == 4 branch, raising NameError on the first 4-element
                # batch (and reusing stale extras afterwards).
                args = tuple()
                if len(element) == 2:
                    x, y = element
                    x_m, y_m = None, None
                elif len(element) == 3:
                    x, y, x_m = element
                    y_m = None
                elif len(element) == 4:
                    x, y, x_m, y_m = element
                elif len(element) > 4:
                    x, y = element[0], element[1]
                    x_m, y_m = element[2], element[3]
                    args = element[4:]
                else:
                    raise ValueError("DataLoader elements must have at least two elements.")

                x, y = x.to(controller.device, non_blocking=True), y.to(controller.device, non_blocking=True)
                if x_m is not None:
                    x_m = x_m.to(controller.device, non_blocking=True)
                if y_m is not None:
                    y_m = y_m.to(controller.device, non_blocking=True)

                # autocast(enabled=False) is a no-op, so a single code path
                # covers both the AMP case (autoscaler set, CUDA device) and
                # the plain-precision case — behavior is unchanged.
                use_amp = controller.autoscaler is not None and controller.device.type == 'cuda'
                with torch.amp.autocast(enabled=use_amp, device_type=controller.device.type):
                    y_hat = model(x, x_m, *args) if x_m is not None else model(x)
                    loss_metric = loss(y_hat, y, y_m) if y_m is not None else loss(y_hat, y)

                    if additional_metrics:
                        # NOTE(review): metrics always receive the target mask
                        # (possibly None) — they must accept three arguments.
                        for name, additional_metric in additional_metrics.items():
                            metrics[name].append(additional_metric(y_hat, y, y_m).item())

                losses.append(loss_metric.item())

    losses = np.array(losses)
    mean_loss = float(np.mean(losses))

    # Reduce each per-batch metric list to its epoch mean.
    for name, variable in metrics.items():
        metrics[name] = float(np.mean(variable))
    metrics['loss'] = mean_loss

    controller.register("val_loss", mean_loss)

    controller.logger.info(f"Epoch [{controller.epoch}]: val_loss = {mean_loss:.8f}")

    return metrics
|
|
|
|
|
|
|
|
|
|
|
|