| import torch |
| import torch.nn as nn |
| import numpy as np |
| from typing import Tuple, Callable, Optional |
|
|
|
|
def normalize_data(data: torch.Tensor, mean: Optional[torch.Tensor] = None,
                   std: Optional[torch.Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Normalize data to zero mean and unit variance.

    Statistics are computed over all elements of ``data`` (a single scalar
    mean and a single scalar std), unless precomputed values are supplied.

    Args:
        data: Input tensor to normalize
        mean: Optional precomputed mean (if None, computed from data)
        std: Optional precomputed std (if None, computed from data). A plain
            Python number is accepted and converted to a tensor.

    Returns:
        Tuple of (normalized_data, mean, std), where std is the value actually
        used for the division (i.e. after clamping to at least 1e-8)
    """
    if mean is None:
        mean = data.mean()

    if std is None:
        if data.numel() > 1:
            std = data.std()
        else:
            # The sample std of fewer than two values is NaN, and clamp() does
            # not repair NaN, so the whole output would become NaN. Treat the
            # spread as zero instead and let the clamp below handle it.
            float_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
            std = torch.zeros((), dtype=float_dtype, device=data.device)
    elif not isinstance(std, torch.Tensor):
        # torch.clamp only accepts tensors; lift Python scalars to a 0-dim tensor.
        float_dtype = data.dtype if data.is_floating_point() else torch.get_default_dtype()
        std = torch.as_tensor(std, dtype=float_dtype, device=data.device)

    # Guard against division by zero for constant data.
    std = torch.clamp(std, min=1e-8)

    normalized = (data - mean) / std
    return normalized, mean, std
|
|
|
|
def denormalize_data(normalized_data: torch.Tensor, mean: torch.Tensor,
                     std: torch.Tensor) -> torch.Tensor:
    """
    Undo a zero-mean / unit-variance normalization.

    Args:
        normalized_data: Tensor that was previously normalized
        mean: Mean that was subtracted during normalization
        std: Standard deviation that the data was divided by

    Returns:
        Tensor mapped back to the original scale
    """
    # Reverse the two normalization steps in the opposite order:
    # first undo the division, then undo the shift.
    rescaled = normalized_data * std
    return rescaled + mean
|
|
|
|
def mean_pooling(x: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Average a tensor over one dimension.

    Args:
        x: Input tensor
        dim: Dimension to average over (removed from the result)

    Returns:
        Tensor holding the mean of ``x`` along ``dim``
    """
    pooled = torch.mean(x, dim=dim)
    return pooled
|
|
|
|
def masked_mean_pooling(x: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
    """
    Apply mean pooling along specified dimension, excluding masked (padded) positions.

    Values at masked-out positions never influence the result, even when they
    are NaN or infinite (e.g. a large negative padding value that overflowed
    in a low-precision dtype).

    Args:
        x: Input tensor (B, seq_len, dim)
        mask: Mask tensor (B, seq_len) where True / non-zero indicates real data
        dim: Dimension to pool over (default: 1, sequence dimension)

    Returns:
        Mean-pooled tensor excluding masked positions. Slices with no
        unmasked position pool to zero.
    """
    if mask.dim() == 2 and x.dim() == 3:
        # Add a trailing feature axis so the mask broadcasts against x.
        mask = mask.unsqueeze(-1)

    weights = mask.float()

    # Multiplying by zero is not enough to drop a value: nan * 0 and inf * 0
    # are both nan. Replace masked entries with an exact zero first.
    zero = torch.zeros((), dtype=x.dtype, device=x.device)
    safe_x = torch.where(weights != 0, x, zero)

    sum_x = (safe_x * weights).sum(dim=dim)

    # Number (or total weight) of contributing positions; clamped so that a
    # fully masked slice divides by a tiny constant instead of zero.
    count = torch.clamp(weights.sum(dim=dim), min=1e-8)

    return sum_x / count
|
|
|
|
|
|
|
|
def pad_sequences(sequences: list, max_length: Optional[int] = None,
                  padding_value: float = -1e9) -> torch.Tensor:
    """
    Pad sequences to the same length with a configurable padding value.

    Sequences longer than ``max_length`` are truncated. The dtype, device and
    trailing (per-step) shape of the result are taken from the first sequence.

    Args:
        sequences: Non-empty list of tensors whose first dimension is the
            (possibly different) sequence length
        max_length: Maximum length to pad to (if None, use longest sequence)
        padding_value: Value to use for padding (default: -1e9, avoids conflict with meaningful zeros)

    Returns:
        Padded tensor of shape (batch_size, max_length, dim) for 2-D
        sequences; in general (batch_size, max_length, *trailing_shape)

    Raises:
        ValueError: If ``sequences`` is empty
    """
    if len(sequences) == 0:
        raise ValueError("pad_sequences() requires at least one sequence")

    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)

    first = sequences[0]
    # Everything after the length axis is kept as-is, so 1-D sequences and
    # sequences with several feature axes are padded correctly too.
    trailing_shape = tuple(first.shape[1:])

    padded = torch.full((len(sequences), max_length) + trailing_shape, padding_value,
                        dtype=first.dtype, device=first.device)

    for i, seq in enumerate(sequences):
        length = min(seq.size(0), max_length)  # truncate over-long sequences
        padded[i, :length] = seq[:length]

    return padded
|
|
|
|
def create_padding_mask(sequences: list, max_length: Optional[int] = None) -> torch.Tensor:
    """
    Build a boolean mask marking real versus padded positions.

    Args:
        sequences: List of tensors whose first dimension is the sequence length
        max_length: Width of the mask (if None, the longest sequence length)

    Returns:
        Boolean tensor of shape (batch_size, max_length), created on the
        device of the first sequence; True marks real data, False padding
    """
    if max_length is None:
        max_length = max(seq.size(0) for seq in sequences)

    device = sequences[0].device

    # Effective length of every sequence once truncated to the mask width.
    valid_lengths = torch.tensor(
        [min(seq.size(0), max_length) for seq in sequences],
        dtype=torch.long,
        device=device,
    )

    # A position is real data exactly when its index is below the row's length.
    positions = torch.arange(max_length, device=device)
    return positions.unsqueeze(0) < valid_lengths.unsqueeze(1)
|
|
|
|
|
|
|
|
def compute_rmse(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Root Mean Square Error between predictions and targets.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        RMSE over all elements, as a Python float
    """
    squared_error = (predictions - targets) ** 2
    rmse = squared_error.mean().sqrt()
    return rmse.item()
|
|
|
|
def compute_mae(predictions: torch.Tensor, targets: torch.Tensor) -> float:
    """
    Mean Absolute Error between predictions and targets.

    Args:
        predictions: Predicted values
        targets: True target values

    Returns:
        MAE over all elements, as a Python float
    """
    absolute_error = (predictions - targets).abs()
    return absolute_error.mean().item()
|
|
|
|
class EarlyStopping:
    """
    Early stopping utility to stop training when validation loss stops improving.

    Call the instance once per epoch with the current validation loss and the
    model. It tracks the best loss seen so far and, optionally, a snapshot of
    the model weights at that point.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.0,
                 restore_best_weights: bool = True):
        """
        Args:
            patience: Number of epochs with no improvement after which training will be stopped
            min_delta: Minimum change in monitored quantity to qualify as improvement
            restore_best_weights: Whether to restore model weights from the best epoch
        """
        self.patience = patience
        self.min_delta = min_delta
        self.restore_best_weights = restore_best_weights

        self.best_loss = float('inf')  # lowest validation loss seen so far
        self.counter = 0               # consecutive calls without improvement
        self.best_weights = None       # snapshot of the state dict at best_loss

    def __call__(self, val_loss: float, model: nn.Module) -> bool:
        """
        Check if training should be stopped.

        When stopping is signalled and ``restore_best_weights`` is set, the
        model is loaded with the weights captured at the best loss.

        Args:
            val_loss: Current validation loss
            model: Model to potentially save weights for

        Returns:
            True if training should be stopped, False otherwise
        """
        import copy

        if val_loss < self.best_loss - self.min_delta:
            self.best_loss = val_loss
            self.counter = 0
            if self.restore_best_weights:
                # state_dict() tensors share storage with the live parameters,
                # so a shallow dict copy would silently follow later training
                # steps. A deep copy freezes the weights as they are now.
                self.best_weights = copy.deepcopy(model.state_dict())
        else:
            self.counter += 1

        if self.counter >= self.patience:
            if self.restore_best_weights and self.best_weights is not None:
                model.load_state_dict(self.best_weights)
            return True

        return False