import argparse
import datetime
import itertools
import os
import pickle
import random
import warnings
from dataclasses import dataclass
from typing import List, Optional, Tuple

import numpy as np
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, Dataset, DistributedSampler

from models.time_rcd.ts_encoder_bi_bias import TimeSeriesEncoder
from models.time_rcd.time_rcd_config import TimeRCDConfig, default_config

warnings.filterwarnings("ignore")


@dataclass
class PretrainBatch:
    """Batch structure for pretraining tasks."""
    time_series: torch.Tensor
    labels: torch.Tensor
    masked_time_series: torch.Tensor
    mask_indices: torch.Tensor


class ChatTSAnomalyPretrainDataset(Dataset):
    """Loads a pickled ChatTS anomaly dataset and splits it into train/test."""

    def __init__(self,
                 dataset_dir: str,
                 filename: str,
                 split: str = 'train',
                 train_ratio: float = 0.95,
                 seed: int = 42):
        file_path = os.path.join(dataset_dir, filename)
        with open(file_path, 'rb') as f:
            dataset = pickle.load(f)

        # Shuffle with a fixed seed so the train/test split is reproducible.
        random.seed(seed)
        indices = list(range(len(dataset)))
        random.shuffle(indices)
        num_train = int(len(dataset) * train_ratio)
        if split == 'train':
            selected_indices = indices[:num_train]
        elif split == 'test':
            selected_indices = indices[num_train:]
        else:
            raise ValueError("split must be 'train' or 'test'")
        self.data = [dataset[i] for i in selected_indices]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        sample = self.data[idx]
        time_series = torch.tensor(sample['time_series'], dtype=torch.float32)
        normal_time_series = torch.tensor(sample['normal_time_series'], dtype=torch.float32)
        labels = torch.tensor(sample['labels'], dtype=torch.long)
        attribute = sample['attribute']
        return time_series, normal_time_series, labels, attribute
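

# Illustrative usage (a sketch, never executed by this script): each pickled
# sample is assumed to be a dict carrying the keys read in __getitem__ above.
#
#   train_set = ChatTSAnomalyPretrainDataset("training_data/", "dataset_0_1.pkl",
#                                            split="train")
#   ts, normal_ts, labels, attr = train_set[0]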


class TimeSeriesPretrainModel(nn.Module):
    """Model for time series pretraining with masked reconstruction and anomaly detection."""

    def __init__(self, config: TimeRCDConfig):
        super().__init__()
        self.config = config

        ts_config = config.ts_config
        self.ts_encoder = TimeSeriesEncoder(
            d_model=ts_config.d_model,
            d_proj=ts_config.d_proj,
            patch_size=ts_config.patch_size,
            num_layers=ts_config.num_layers,
            num_heads=ts_config.num_heads,
            d_ff_dropout=ts_config.d_ff_dropout,
            use_rope=ts_config.use_rope,
            num_features=ts_config.num_features,
            activation=ts_config.activation
        )

        # Per-timestep regression head for masked reconstruction.
        self.reconstruction_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, ts_config.d_proj * 4),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj * 4, 1)
        )

        # Per-timestep binary classification head (normal vs. anomalous).
        self.anomaly_head = nn.Sequential(
            nn.Linear(ts_config.d_proj, ts_config.d_proj // 2),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(ts_config.d_proj // 2, 2)
        )

        self.anomaly_head.apply(self._init_weights)
        self.reconstruction_head.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.xavier_normal_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(self, time_series: torch.Tensor, mask: Optional[torch.Tensor] = None):
        """Forward pass through the encoder; returns per-timestep embeddings."""
        local_embeddings = self.ts_encoder(time_series, mask)
        return local_embeddings

    def masked_reconstruction_loss(self,
                                   local_embeddings: torch.Tensor,
                                   original_time_series: torch.Tensor,
                                   mask: torch.Tensor) -> torch.Tensor:
        """Compute MSE reconstruction loss on the masked timesteps only."""
        batch_size, seq_len, num_features = original_time_series.shape

        mask = mask.bool()

        # The head maps each embedding to a scalar; with per-feature embeddings
        # of shape (batch, seq_len, num_features, d_proj), as the view below
        # implies, this yields one reconstructed value per timestep and feature.
        reconstructed = self.reconstruction_head(local_embeddings)
        reconstructed = reconstructed.view(batch_size, seq_len, num_features)

        # Only the masked positions contribute to the loss.
        mask_expanded = mask.unsqueeze(-1).expand(-1, -1, num_features)
        reconstruction_loss = F.mse_loss(
            reconstructed[mask_expanded],
            original_time_series[mask_expanded]
        )
        return reconstruction_loss

    def anomaly_detection_loss(self,
                               local_embeddings: torch.Tensor,
                               labels: torch.Tensor) -> torch.Tensor:
        """Compute anomaly detection loss for each timestep."""
        # Average the per-feature logits so each timestep gets one prediction.
        logits = self.anomaly_head(local_embeddings)
        logits = torch.mean(logits, dim=-2)

        logits = logits.view(-1, 2)
        labels = labels.view(-1)

        # Padded positions carry label -1. Build the validity mask *before*
        # binarizing: thresholding first would turn -1 into 0 and silently
        # train on padding.
        valid_mask = (labels != -1)
        labels = (labels > 0.5).long()

        if valid_mask.sum() > 0:
            anomaly_loss = F.cross_entropy(
                logits[valid_mask],
                labels[valid_mask]
            )
        else:
            anomaly_loss = torch.tensor(0.0, device=logits.device)

        return anomaly_loss
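

def _demo_loss_head_shapes() -> None:
    """Sketch only (never called by the training flow): walks through the shape
    bookkeeping the two heads above rely on, using stand-in linear layers and a
    made-up d_proj. Assumes the encoder emits embeddings of shape
    (batch, seq_len, num_features, d_proj), which is what the .view and
    .mean(dim=-2) calls above imply."""
    batch, seq_len, num_features, d_proj = 2, 12, 3, 8
    emb = torch.randn(batch, seq_len, num_features, d_proj)
    recon_head = nn.Linear(d_proj, 1)   # stand-in for reconstruction_head
    cls_head = nn.Linear(d_proj, 2)     # stand-in for anomaly_head
    recon = recon_head(emb).view(batch, seq_len, num_features)  # (B, T, F)
    logits = cls_head(emb).mean(dim=-2)                         # (B, T, 2)
    assert recon.shape == (batch, seq_len, num_features)
    assert logits.shape == (batch, seq_len, 2)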


def create_random_mask(time_series: torch.Tensor,
                       attention_mask: torch.Tensor,
                       mask_ratio: float = 0.15) -> Tuple[torch.Tensor, torch.Tensor]:
    """Create a random patch mask, masking only the valid (non-padded) part of
    each sequence."""
    batch_size, seq_len, num_features = time_series.shape
    # Note: reads the patch size from the global default_config rather than a
    # passed-in config; main() mutates default_config, so the two stay in sync.
    patch_size = default_config.ts_config.patch_size

    mask = torch.zeros(batch_size, seq_len)

    for i in range(batch_size):
        valid_length = attention_mask[i].sum().item()

        # Number of (possibly partial) patches covering the valid region.
        num_valid_patches = (valid_length - 1) // patch_size + 1
        num_masked = int(num_valid_patches * mask_ratio)

        if num_masked > 0:
            # Pick patches uniformly at random and mark their timesteps.
            masked_patches = torch.randperm(num_valid_patches)[:num_masked]
            for j in masked_patches:
                start_idx = j * patch_size
                end_idx = min((j + 1) * patch_size, valid_length)
                mask[i, start_idx:end_idx] = 1

    # Zero out the masked positions in the input.
    masked_time_series = time_series.clone()
    mask_indices = mask.bool() & attention_mask
    mask_expanded = mask_indices.unsqueeze(-1).expand(-1, -1, num_features)
    masked_time_series[mask_expanded] = 0.0

    # Restrict the returned mask to valid positions.
    mask = mask * attention_mask.float()

    return masked_time_series, mask
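

def _demo_random_mask() -> None:
    """Sketch only (never called by the script): masks ~15% of the valid
    patches of a toy batch and checks that padding is never masked."""
    ts = torch.randn(2, 64, 3)                    # (batch, seq_len, features)
    attn = torch.ones(2, 64, dtype=torch.bool)
    attn[1, 48:] = False                          # second series is shorter
    masked_ts, mask = create_random_mask(ts, attn, mask_ratio=0.15)
    assert masked_ts.shape == ts.shape and mask.shape == (2, 64)
    assert not mask[1, 48:].any()                 # padded region stays unmasked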


def collate_fn(batch):
    """Collate function for the pretraining dataset.

    Returns a dict with padded 'time_series', 'normal_time_series', and
    'masked_time_series' of shape (B, max_seq_len, num_features), a float
    'mask' and boolean 'attention_mask' of shape (B, max_seq_len), 'labels'
    padded with -1, and the raw 'attribute' tuple.
    """
    time_series_list, normal_time_series_list, labels_list, attribute_list = zip(*batch)

    # Ensure every sample has an explicit feature dimension.
    if time_series_list[0].ndim == 1:
        time_series_tensors = [ts.unsqueeze(-1) for ts in time_series_list]
        normal_time_series_tensors = [nts.unsqueeze(-1) for nts in normal_time_series_list]
    else:
        time_series_tensors = list(time_series_list)
        normal_time_series_tensors = list(normal_time_series_list)

    # Standardize per feature using statistics pooled over the whole batch.
    concatenated = torch.cat(time_series_tensors, dim=0)
    mean = concatenated.mean(dim=0, keepdim=True)
    std = concatenated.std(dim=0, keepdim=True) + 1e-4
    time_series_tensors = [(ts - mean) / std for ts in time_series_tensors]
    normal_time_series_tensors = [(nts - mean) / std for nts in normal_time_series_tensors]

    labels = list(labels_list)

    padded_time_series = torch.nn.utils.rnn.pad_sequence(
        time_series_tensors, batch_first=True, padding_value=0.0
    )
    padded_normal_time_series = torch.nn.utils.rnn.pad_sequence(
        normal_time_series_tensors, batch_first=True, padding_value=0.0
    )
    padded_labels = torch.nn.utils.rnn.pad_sequence(
        labels, batch_first=True, padding_value=-1
    )

    # Boolean mask marking the valid (non-padded) part of each sequence.
    sequence_lengths = [ts.size(0) for ts in time_series_tensors]
    B, max_seq_len, num_features = padded_time_series.shape
    attention_mask = torch.zeros(B, max_seq_len, dtype=torch.bool)
    for i, length in enumerate(sequence_lengths):
        attention_mask[i, :length] = True

    masked_time_series, mask = create_random_mask(padded_time_series, attention_mask)

    return {
        'time_series': padded_time_series,
        'normal_time_series': padded_normal_time_series,
        'masked_time_series': masked_time_series,
        'mask': mask,
        'labels': padded_labels,
        'attention_mask': attention_mask,
        'attribute': attribute_list
    }
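

def _demo_collate() -> None:
    """Sketch only (never called by the script): collates two variable-length
    toy samples and checks the padded shapes and padding values."""
    a = (torch.randn(50), torch.randn(50), torch.zeros(50, dtype=torch.long), {})
    b = (torch.randn(30), torch.randn(30), torch.ones(30, dtype=torch.long), {})
    out = collate_fn([a, b])
    assert out['time_series'].shape == (2, 50, 1)   # padded to the max length
    assert out['labels'][1, 30:].eq(-1).all()       # label padding is -1
    assert not out['attention_mask'][1, 30:].any()  # padding marked invalid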


def set_seed(seed: int) -> None:
    """Set random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False


def setup_distributed(rank: int, world_size: int, config: TimeRCDConfig) -> None:
    """Set up the distributed training environment."""
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = config.dist_port

    try:
        dist.init_process_group(
            "nccl",
            rank=rank,
            world_size=world_size,
            timeout=datetime.timedelta(minutes=30)
        )
        torch.cuda.set_device(rank)

        if rank == 0:
            print(f"Successfully initialized distributed training on rank {rank} with world size {world_size}")

    except Exception as e:
        print(f"Rank {rank}: initialization failed with error: {e}")
        raise


def cleanup_distributed() -> None:
    """Clean up the distributed training environment."""
    if dist.is_initialized():
        dist.destroy_process_group()


def evaluate_epoch(test_loader: DataLoader,
                   model: nn.Module,
                   config: TimeRCDConfig,
                   device: torch.device,
                   rank: int) -> float:
    """Evaluate the model on the test dataset, capped at config.test_batch_limit batches."""
    model.eval()
    total_loss = 0.0
    total_recon_loss = 0.0
    total_anomaly_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in itertools.islice(test_loader, config.test_batch_limit):
            time_series = batch['time_series'].to(device)
            masked_time_series = batch['masked_time_series'].to(device)
            mask = batch['mask'].to(device)
            labels = batch['labels'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Masked positions are hidden from the encoder as well.
            local_embeddings = model(masked_time_series, attention_mask & (~mask.bool()))

            recon_loss = model.module.masked_reconstruction_loss(
                local_embeddings, time_series, mask
            )
            anomaly_loss = model.module.anomaly_detection_loss(local_embeddings, labels)

            total_loss += (recon_loss + anomaly_loss).item()
            total_recon_loss += recon_loss.item()
            total_anomaly_loss += anomaly_loss.item()
            num_batches += 1

    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    avg_recon_loss = total_recon_loss / num_batches if num_batches > 0 else 0.0
    avg_anomaly_loss = total_anomaly_loss / num_batches if num_batches > 0 else 0.0

    if rank == 0:
        print("Validation Results:")
        print(f"  Average Total Loss: {avg_loss:.4f}")
        print(f"  Average Recon Loss: {avg_recon_loss:.4f}")
        print(f"  Average Anomaly Loss: {avg_anomaly_loss:.4f}")

    return avg_loss


def train_epoch(train_loader: DataLoader,
                model: nn.Module,
                optimizer: optim.Optimizer,
                config: TimeRCDConfig,
                device: torch.device,
                epoch: int,
                rank: int,
                scaler: Optional[torch.amp.GradScaler] = None) -> float:
    """Train for one epoch on the two pretraining tasks."""
    model.train()
    total_loss = 0.0
    total_recon_loss = 0.0
    total_anomaly_loss = 0.0
    num_batches = 0

    for batch_idx, batch in enumerate(train_loader):
        if batch_idx % 10 == 0:
            torch.cuda.empty_cache()

        optimizer.zero_grad()

        time_series = batch['time_series'].to(device)
        masked_time_series = batch['masked_time_series'].to(device)
        mask = batch['mask'].to(device)
        labels = batch['labels'].to(device)
        attention_mask = batch['attention_mask'].to(device)

        if config.mixed_precision and scaler is not None:
            # Forward pass and loss computation run under autocast; the scaled
            # backward pass and optimizer step happen outside it.
            with torch.amp.autocast('cuda'):
                local_embeddings = model(masked_time_series, attention_mask & (~mask.bool()))
                recon_loss = model.module.masked_reconstruction_loss(
                    local_embeddings, time_series, mask
                )
                anomaly_loss = model.module.anomaly_detection_loss(local_embeddings, labels)
                total_loss_batch = recon_loss + anomaly_loss

            scaler.scale(total_loss_batch).backward()
            scaler.step(optimizer)
            scaler.update()
        else:
            local_embeddings = model(masked_time_series, attention_mask & (~mask.bool()))

            recon_loss = model.module.masked_reconstruction_loss(
                local_embeddings, time_series, mask
            )
            anomaly_loss = model.module.anomaly_detection_loss(local_embeddings, labels)

            total_loss_batch = recon_loss + anomaly_loss
            total_loss_batch.backward()
            optimizer.step()

        total_loss += total_loss_batch.item()
        total_recon_loss += recon_loss.item()
        total_anomaly_loss += anomaly_loss.item()
        num_batches += 1

        if rank == 0 and batch_idx % config.log_freq == 0:
            print(f"Epoch {epoch}, Batch {batch_idx}/{len(train_loader)}")
            print(f"  Total Loss: {total_loss_batch.item():.4f}")
            print(f"  Recon Loss: {recon_loss.item():.4f}")
            print(f"  Anomaly Loss: {anomaly_loss.item():.4f}")

    avg_loss = total_loss / num_batches
    avg_recon_loss = total_recon_loss / num_batches
    avg_anomaly_loss = total_anomaly_loss / num_batches

    if rank == 0:
        print(f"Epoch {epoch} completed:")
        print(f"  Average Total Loss: {avg_loss:.4f}")
        print(f"  Average Recon Loss: {avg_recon_loss:.4f}")
        print(f"  Average Anomaly Loss: {avg_anomaly_loss:.4f}")

    return avg_loss


def save_checkpoint(model: nn.Module,
                    optimizer: optim.Optimizer,
                    config: TimeRCDConfig,
                    epoch: int,
                    avg_loss: float,
                    rank: int = 0,
                    is_best: bool = False) -> None:
    """Save a model checkpoint (rank 0 only)."""
    if rank != 0:
        return

    # Unwrap DDP so the checkpoint is loadable without the 'module.' prefix.
    if hasattr(model, 'module'):
        model_state_dict = model.module.state_dict()
    else:
        model_state_dict = model.state_dict()

    checkpoint = {
        'epoch': epoch,
        'model_state_dict': model_state_dict,
        'optimizer_state_dict': optimizer.state_dict(),
        'loss': avg_loss,
        'config': config.to_dict()
    }

    os.makedirs(config.checkpoint_dir, exist_ok=True)

    # Always overwrite the rolling "latest" checkpoint.
    latest_path = os.path.join(config.checkpoint_dir, "pretrain_checkpoint_latest.pth")
    torch.save(checkpoint, latest_path)

    # Periodic snapshots, plus one at the final epoch.
    if epoch % config.save_freq == 0 or epoch == config.num_epochs - 1:
        save_path = os.path.join(config.checkpoint_dir, f"pretrain_checkpoint_epoch_{epoch}.pth")
        torch.save(checkpoint, save_path)
        print(f"Checkpoint saved to {save_path} (epoch {epoch}, loss: {avg_loss:.4f})")

    if is_best:
        best_path = os.path.join(config.checkpoint_dir, "pretrain_checkpoint_best.pth")
        torch.save(checkpoint, best_path)
        print(f"New best model saved to {best_path} (epoch {epoch}, val_loss: {avg_loss:.4f})")

        # Also export the encoder weights on their own for downstream use.
        if hasattr(model, 'module'):
            ts_encoder_state = model.module.ts_encoder.state_dict()
        else:
            ts_encoder_state = model.ts_encoder.state_dict()

        best_encoder_path = os.path.join(config.checkpoint_dir, "pretrained_ts_encoder.pth")
        torch.save(ts_encoder_state, best_encoder_path)
        print(f"Best pretrained time series encoder saved to {best_encoder_path}")
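

# Sketch of how the exported encoder weights could later be reloaded (not part
# of this script's flow; assumes a TimeSeriesEncoder built with the same
# hyperparameters that were used for pretraining):
#
#   encoder = TimeSeriesEncoder(...)
#   state = torch.load("checkpoints/<dataset>/pretrained_ts_encoder.pth",
#                      map_location="cpu")
#   encoder.load_state_dict(state)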


def train_multiple_datasets(dataset_filenames: List[str], config: TimeRCDConfig) -> None:
    """Train on multiple datasets sequentially, carrying model weights over between datasets."""
    print(f'\n{"=" * 50}')
    print("Starting Multi-Dataset Sequential Training")
    print(f"Number of datasets: {len(dataset_filenames)}")
    print(f"Datasets: {dataset_filenames}")
    print("Training Mode: Continuous (model weights carried over between datasets)")
    print("=" * 50)

    gpu_ids = [int(x.strip()) for x in config.cuda_devices.split(',')]
    world_size = len(gpu_ids)

    os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda_devices

    # Path of the best checkpoint from the previous dataset, if any.
    global_checkpoint_path = None

    # Per-feature-count batch sizes: entry k-1 is used for datasets with k
    # features, so wider series get smaller batches.
    batch_size_list = [256, 64, 64, 32, 32, 16, 16, 48,
                       16, 16, 16, 32, 16, 16, 16, 16,
                       16, 16, 16, 16, 12, 12, 12, 16,
                       12, 12, 12, 12, 12, 12, 12, 16,
                       12, 12, 12, 12, 12, 12, 12, 12,
                       12, 12, 12, 12, 12, 12, 12, 12,
                       12, 12, 12, 12, 12, 12, 12, 8]

    for dataset_idx, filename in enumerate(dataset_filenames):
        print(f"\n{'=' * 50}")
        print(f"Training on Dataset {dataset_idx + 1}/{len(dataset_filenames)}: {filename}")
        if global_checkpoint_path is not None:
            print("Continuing from previous dataset's trained model...")
        print(f"{'=' * 50}")

        # The feature count is encoded as the last underscore-separated token
        # of the filename, e.g. "dataset_9_16.pkl" -> 16 features.
        num_features = int(os.path.splitext(filename)[0].split('_')[-1])
        if num_features <= len(batch_size_list):
            batch_size = batch_size_list[num_features - 1]
        else:
            batch_size = batch_size_list[-1]
        print(f"Using batch size: {batch_size} for {filename}")
        config.batch_size = batch_size

        # Give each dataset its own checkpoint subdirectory.
        original_checkpoint_dir = config.checkpoint_dir
        config.checkpoint_dir = os.path.join(original_checkpoint_dir, f"{filename}")
        os.makedirs(config.checkpoint_dir, exist_ok=True)

        config.continuation_checkpoint = global_checkpoint_path
        config.ts_config.num_features = num_features

        if world_size == 1:
            print(f"Running single GPU pretraining for {filename}...")
            train_worker(0, 1, config, filename)
        else:
            print(f"Running distributed pretraining for {filename}...")
            mp.spawn(
                train_worker,
                args=(world_size, config, filename),
                nprocs=world_size,
                join=True
            )

        # The next dataset continues from this dataset's best checkpoint.
        global_checkpoint_path = os.path.join(config.checkpoint_dir, "pretrain_checkpoint_best.pth")

        config.checkpoint_dir = original_checkpoint_dir

        print(f"Completed training on dataset: {filename}")
        if dataset_idx < len(dataset_filenames) - 1:
            print("Model weights will be loaded for the next dataset...")

    print(f"\n{'=' * 50}")
    print("Multi-Dataset Sequential Training Completed!")
    print(f"All {len(dataset_filenames)} datasets have been processed with model continuation.")
    print(f"{'=' * 50}")


def train_worker(rank: int, world_size: int, config: TimeRCDConfig, filename: Optional[str] = None) -> None:
    """Training worker function, run once per process."""
    print(f"Running DDP on rank {rank} with world_size {world_size} for dataset: {filename}")

    setup_distributed(rank, world_size, config)

    device = torch.device(f"cuda:{rank}")

    # Different seed per rank so random masking differs across workers.
    set_seed(config.seed + rank)

    try:
        model = TimeSeriesPretrainModel(config).to(device)

        # Optionally continue from the previous dataset's best checkpoint.
        checkpoint = None
        if getattr(config, 'continuation_checkpoint', None) and os.path.exists(config.continuation_checkpoint):
            if rank == 0:
                print(f"Loading checkpoint from previous dataset: {config.continuation_checkpoint}")
            checkpoint = torch.load(config.continuation_checkpoint, map_location=device)

            state_dict = checkpoint['model_state_dict']

            # Strip a leading 'module.' prefix left over from DDP, if present.
            if any(key.startswith('module.') for key in state_dict.keys()):
                state_dict = {
                    (key[len('module.'):] if key.startswith('module.') else key): value
                    for key, value in state_dict.items()
                }

            # strict=False: shapes may differ when num_features changes
            # between datasets.
            model.load_state_dict(state_dict, strict=False)
            if rank == 0:
                print("Successfully loaded model weights from previous dataset training")

        model = DDP(model, device_ids=[rank])

        optimizer = optim.AdamW(
            model.parameters(),
            lr=config.learning_rate,
            weight_decay=config.weight_decay
        )

        if checkpoint is not None and 'optimizer_state_dict' in checkpoint:
            try:
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                if rank == 0:
                    print("Successfully loaded optimizer state from previous dataset training")
            except Exception as e:
                if rank == 0:
                    print(f"Warning: could not load optimizer state: {e}")
                    print("Continuing with a fresh optimizer state")
                    print("This is normal when the model architecture or optimizer parameters change")

        scaler = torch.amp.GradScaler() if config.mixed_precision else None

        train_dataset = ChatTSAnomalyPretrainDataset(config.pretrain_data_path, filename, split="train")
        test_dataset = ChatTSAnomalyPretrainDataset(config.pretrain_data_path, filename, split="test")

        train_sampler = DistributedSampler(
            train_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=True
        )

        train_loader = DataLoader(
            train_dataset,
            batch_size=config.batch_size,
            sampler=train_sampler,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True
        )

        test_sampler = DistributedSampler(
            test_dataset,
            num_replicas=world_size,
            rank=rank,
            shuffle=False
        )

        test_loader = DataLoader(
            test_dataset,
            batch_size=config.batch_size,
            sampler=test_sampler,
            collate_fn=collate_fn,
            num_workers=2,
            pin_memory=True
        )

        best_val_loss = float('inf')
        patience_counter = 0
        early_stopping_patience = getattr(config, 'early_stopping_patience', 10)

        if rank == 0:
            dataset_name = filename if filename else "default"
            continuation_info = ""
            if getattr(config, 'continuation_checkpoint', None) and os.path.exists(config.continuation_checkpoint):
                continuation_info = " (continuing from previous dataset)"
            print(f"Starting pretraining for {config.num_epochs} epochs on dataset {dataset_name}{continuation_info}...")
            print(f"Total training batches per process: {len(train_loader)}")
            print(f"Total validation batches per process: {min(config.test_batch_limit, len(test_loader))}")
            print(f"Early stopping patience: {early_stopping_patience} epochs")
            print("Tasks: Masked Reconstruction + Anomaly Detection")

        for epoch in range(config.num_epochs):
            # Reshuffle shards each epoch so every rank sees a new ordering.
            train_sampler.set_epoch(epoch)
            test_sampler.set_epoch(epoch)

            avg_train_loss = train_epoch(train_loader, model, optimizer,
                                         config, device, epoch, rank, scaler)

            avg_val_loss = evaluate_epoch(test_loader, model, config, device, rank)

            is_best = avg_val_loss < best_val_loss
            if is_best:
                best_val_loss = avg_val_loss
                patience_counter = 0
                if rank == 0:
                    print(f"New best validation loss: {best_val_loss:.4f}")
            else:
                patience_counter += 1
                if rank == 0:
                    print(f"Validation loss did not improve. Patience: {patience_counter}/{early_stopping_patience}")

            save_checkpoint(model, optimizer, config, epoch, avg_val_loss, rank, is_best)

            if patience_counter >= early_stopping_patience:
                if rank == 0:
                    print(f"Early stopping triggered after {epoch + 1} epochs")
                    print(f"Best validation loss: {best_val_loss:.4f}")
                break

    finally:
        cleanup_distributed()


def main() -> None:
    """Main function to launch distributed pretraining."""
    config = default_config

    # Training hyperparameters.
    config.num_epochs = 50
    config.learning_rate = 5e-4
    config.batch_size = 64
    config.ts_config.patch_size = 16
    config.checkpoint_dir = "checkpoints/"
    config.cuda_devices = "3"
    config.mixed_precision = False
    config.dist_port = "16798"
    config.early_stopping_patience = 7
    config.pretrain_data_path = "training_data/"

    parser = argparse.ArgumentParser()
    parser.add_argument('--mode', type=str, default='multi', choices=['multi', 'single'])
    args = parser.parse_args()

    use_multi_dataset_training = (args.mode == 'multi')

    single_dataset_filename = "uni_data_0_1.pkl"

    dataset_filenames = [
        "dataset_0_1.pkl",
        "dataset_1_1.pkl",
        "dataset_2_1.pkl",
        "dataset_7_8.pkl",
        "dataset_8_12.pkl",
        "dataset_9_16.pkl",
        "dataset_10_20.pkl",
    ]

    gpu_ids = [int(x.strip()) for x in config.cuda_devices.split(',')]
    world_size = len(gpu_ids)

    print(f"Using GPUs: {gpu_ids}")
    print(f"World size: {world_size}")
    print(f"CUDA_VISIBLE_DEVICES: {os.environ.get('CUDA_VISIBLE_DEVICES', config.cuda_devices)}")
    print("=" * 50)
    print("AnomalyLLava Time Series Pretraining")
    print("Tasks:")
    print("  1. Masked Reconstruction - Reconstruct masked time series patches")
    print("  2. Anomaly Detection - Binary classification of normal/anomalous series")
    print("Features:")
    print("  - Early stopping with validation loss monitoring")
    print("  - Best model checkpoint saving")
    print(f"  - Early stopping patience: {config.early_stopping_patience} epochs")
    if use_multi_dataset_training:
        print("  - Sequential multi-dataset training with model weight continuation")
    print("=" * 50)

    os.makedirs(config.checkpoint_dir, exist_ok=True)

    if use_multi_dataset_training:
        print(f"Training Mode: Multi-Dataset Sequential ({len(dataset_filenames)} datasets)")
        print("Datasets will be trained sequentially with model weight continuation")
        train_multiple_datasets(dataset_filenames, config)
    else:
        print(f"Training Mode: Single Dataset ({single_dataset_filename})")

        os.environ['CUDA_VISIBLE_DEVICES'] = config.cuda_devices

        # Feature count is encoded in the filename suffix, e.g. "uni_data_0_1.pkl" -> 1.
        num_features = int(os.path.splitext(single_dataset_filename)[0].split('_')[-1])
        config.ts_config.num_features = num_features
        if world_size == 1:
            print("Running single GPU pretraining...")
            train_worker(0, 1, config, single_dataset_filename)
        else:
            print("Running distributed pretraining...")
            mp.spawn(
                train_worker,
                args=(world_size, config, single_dataset_filename),
                nprocs=world_size,
                join=True
            )


if __name__ == "__main__":
    main()