# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import functools
import os
import signal

import torch
import torch.distributed as dist
import torch.utils.data
from megatron.core import parallel_state

from cosmos_predict1.utils import callback, distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.lazy_config import LazyConfig, instantiate
from cosmos_predict1.utils.model import Model


class Trainer:
    """The base trainer class.

    All trainers should inherit Trainer. It contains the basic functionality for model training
    (particularly suited for large-scale training), including data parallelism (DDP/FSDP), model weight
    averaging (EMA), and mixed-precision training (fp16/bf16).

    Attributes:
        checkpointer (Checkpointer): checkpointer object to save/load model weights and optimizer states.
        training_timer (misc.TrainingTimer): Timer object to time code blocks and functions.
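
    Example:
        A minimal usage sketch (assuming ``config`` follows this codebase's config layout and the two
        dataloaders are built elsewhere; the actual entry script may differ):

            trainer = Trainer(config)
            model = instantiate(config.model)
            trainer.train(model, dataloader_train, dataloader_val)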
| """ | |
    def __init__(self, config):
        """Constructor of the trainer.

        Args:
            config (Config): The config object for the codebase.
        """
        super().__init__()
        self.config = config
        # Set up the distributed computing environment.
        with misc.timer("init_distributed"):
            distributed.init()
        # Set up parallel states.
        if hasattr(config.model, "context_parallel_size"):
            if config.model_parallel.context_parallel_size > 1:
                raise ValueError(
                    "Both config.model.context_parallel_size and config.model_parallel.context_parallel_size are set. "
                    "config.model.context_parallel_size is deprecated. Please only set config.model_parallel.context_parallel_size."
                )
            else:
                log.critical(
                    "Using deprecated config.model.context_parallel_size. Please use config.model_parallel.context_parallel_size instead."
                )
                config.model_parallel.context_parallel_size = config.model.context_parallel_size
        parallel_state.initialize_model_parallel(
            pipeline_model_parallel_size=config.model_parallel.pipeline_model_parallel_size,
            tensor_model_parallel_size=config.model_parallel.tensor_model_parallel_size,
            context_parallel_size=config.model_parallel.context_parallel_size,
        )
        # `config.model_parallel.sequence_parallel` is a bool that indicates whether to use sequence parallelism.
        # It is not part of the original `parallel_state` API, so we need to set it manually.
        parallel_state.sequence_parallel = config.model_parallel.sequence_parallel
        if parallel_state.sequence_parallel:
            os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
        # Create the local job directory, save the config file, and pipe to a local log.
        if distributed.is_rank0():
            os.makedirs(config.job.path_local, exist_ok=True)
            # Save the config as .pkl for reproducibility.
            LazyConfig.save_pkl(config, f"{config.job.path_local}/config.pkl")
            # Save the config as .yaml for reading or parsing experiment hyperparameters.
            LazyConfig.save_yaml(config, f"{config.job.path_local}/config.yaml")
        dist.barrier()
        log.init_loguru_file(f"{config.job.path_local}/stdout.log")
        if distributed.is_rank0():
            # Print important environment variables and the effective config.
            log.info("Config:\n" + config.pretty_print(use_color=True))
            misc.print_environ_variables(["TORCH_HOME", "OUTPUT_ROOT"])
        # Set the random seed. If multi-GPU, different ranks are set with different seeds.
        misc.set_random_seed(seed=config.trainer.seed, by_rank=True)
        # Initialize cuDNN.
        torch.backends.cudnn.deterministic = config.trainer.cudnn.deterministic
        torch.backends.cudnn.benchmark = config.trainer.cudnn.benchmark
        # Floating-point precision settings.
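        # TF32 speeds up float32 matmuls/convolutions on Ampere and newer GPUs at slightly reduced precision.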
        torch.backends.cudnn.allow_tf32 = torch.backends.cuda.matmul.allow_tf32 = True
        # Initialize the callback functions.
        self.callbacks = callback.CallBackGroup(config=config, trainer=self)
        # Initialize the model checkpointer.
        if config.checkpoint.type is None:
            self.checkpointer = Checkpointer(config.checkpoint, config.job, callbacks=self.callbacks)
        else:
            self.checkpointer: Checkpointer = instantiate(
                config.checkpoint.type, config.checkpoint, config.job, callbacks=self.callbacks
            )
        # Initialize the timer for speed benchmarking.
        self.training_timer = misc.TrainingTimer()
        # Raise a TimeoutError (via misc.timeout_handler) if a training step takes longer than timeout_period seconds.
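        # Note: the alarm itself is (re)armed with signal.alarm(...) at the end of each successful
        # training iteration in train(), so a hung step eventually triggers the handler.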
        signal.signal(signal.SIGALRM, functools.partial(misc.timeout_handler, config.trainer.timeout_period))  # type: ignore

    def train(
        self,
        model: Model,
        dataloader_train: torch.utils.data.DataLoader,
        dataloader_val: torch.utils.data.DataLoader,
    ) -> None:
        """The training function.

        Args:
            model (Model): The PyTorch model.
            dataloader_train (torch.utils.data.DataLoader): The training data loader.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
        """
        # Leaving this here for backward compatibility for now, but we could consider moving it to model.on_train_start for all models.
| model = model.to("cuda", memory_format=self.config.trainer.memory_format) # type: ignore | |
| model.on_train_start(self.config.trainer.memory_format) | |
| # Initialize the optimizer, scheduler, and grad_scaler. | |
| self.callbacks.on_optimizer_init_start() | |
| optimizer, scheduler = model.init_optimizer_scheduler(self.config.optimizer, self.config.scheduler) | |
| grad_scaler = torch.amp.GradScaler("cuda", **self.config.trainer.grad_scaler_args) | |
| self.callbacks.on_optimizer_init_end() | |
| # Load the model checkpoint and get the starting iteration number. | |
| iteration = self.checkpointer.load(model, optimizer, scheduler, grad_scaler) | |
| grad_accum_iter = 0 | |
| log.critical(f"Distributed parallelism mode: {self.config.trainer.distributed_parallelism}") | |
| if self.config.trainer.distributed_parallelism == "ddp": | |
| # Create a DDP model wrapper. | |
| model_ddp = distributed.parallel_model_wrapper(self.config.trainer.ddp, model) | |
| elif self.config.trainer.distributed_parallelism == "fsdp": | |
| model_ddp = model | |
| else: | |
| raise ValueError(f"Unknown distributed parallelism mode: {self.config.trainer.distributed_parallelism}") | |
| log.info("Starting training...") | |
| self.callbacks.on_train_start(model, iteration=iteration) | |
| # Initial validation. | |
| if self.config.trainer.run_validation and iteration == 0: | |
| self.validate(model, dataloader_val, iteration=iteration) | |
| _end_training = False | |
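        # The outer loop restarts the dataloader iterator (a fresh pass over the data) each time the
        # inner loop exhausts it; training only stops once max_iter is reached.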
        while True:
            dataloader_train_iter = iter(dataloader_train)
            while True:
                self.callbacks.on_before_dataloading(iteration)
                with self.training_timer("dataloader_train"):
                    try:
                        data_batch = next(dataloader_train_iter)
                        for k in data_batch.keys():
                            if torch.is_tensor(data_batch[k]):
                                data_batch[k] = data_batch[k].cuda()
                    except StopIteration:
                        break
                self.callbacks.on_after_dataloading(iteration)
                # If max_iter is reached, exit the training loop.
                if iteration >= self.config.trainer.max_iter:
                    _end_training = True
                    break
                # Move all tensors in the data batch to GPU device.
                data_batch = misc.to(data_batch, device="cuda")
                # The actual training step.
                self.callbacks.on_training_step_start(model, data_batch, iteration=iteration)
                if not model.training:
                    model_ddp.train()
                assert model_ddp.training, "model_ddp is not in training mode."
                assert model.training, "model is not in training mode."
                output_batch, loss, grad_accum_iter = self.training_step(
                    model_ddp,
                    optimizer,
                    scheduler,
                    grad_scaler,
                    data_batch,
                    iteration=iteration,
                    grad_accum_iter=grad_accum_iter,
                )
                # Do the following when an actual optimizer (update) step has been made.
                iteration += 1
                # Save checkpoint.
                if iteration % self.config.checkpoint.save_iter == 0:
                    async_saving = getattr(self.config.checkpoint, "async_saving", True)
                    self.checkpointer.save(
                        model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving
                    )
                self.callbacks.on_training_step_end(model, data_batch, output_batch, loss, iteration=iteration)
                # Validation.
                if self.config.trainer.run_validation and iteration % self.config.trainer.validation_iter == 0:
                    self.validate(model, dataloader_val, iteration=iteration)
                # This iteration is successful; reset the timeout signal.
                signal.alarm(self.config.trainer.timeout_period)
            if _end_training:
                break
        log.success("Done with training.")
        if iteration % self.config.checkpoint.save_iter != 0:
            async_saving = getattr(self.config.checkpoint, "async_saving", True)
            self.checkpointer.save(
                model, optimizer, scheduler, grad_scaler, iteration=iteration, async_saving=async_saving
            )
        self.callbacks.on_train_end(model, iteration=iteration)
        self.checkpointer.finalize()
        distributed.barrier()
        self.callbacks.on_app_end()

    def training_step(
        self,
        model_ddp: torch.nn.Module | distributed.DistributedDataParallel,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        data: dict[str, torch.Tensor],
        iteration: int = 0,
        grad_accum_iter: int = 0,
    ) -> tuple[dict[str, torch.Tensor], torch.Tensor, int]:
| """The training step. | |
| Args: | |
| model_ddp (torch.nn.Module | distributed.DistributedDataParallel): The model with a DDP wrapper or, the bare | |
| module, depending on whether distributed training is enabled or not. | |
| optimizer (torch.optim.Optimizer): The model optimizer. | |
| scheduler (torch.optim.lr_scheduler.LRScheduler): The optimization scheduler. | |
| grad_scaler (torch.amp.GradScaler): The gradient scaler (for mixed precision training). | |
| data (dict[str, torch.Tensor]): Data batch (dictionary of tensors). | |
| iteration (int): Current iteration number. | |
| grad_accum_iter (int): Number of gradient accumulation iterations. | |
| Returns: | |
| output (dict[str, torch.Tensor]): The model output from the training data batch (dictionary of tensors). | |
| loss (torch.Tensor): The total loss of the training data batch. | |
| """ | |
        # Only let DDP sync gradient at the last iteration of the gradient accumulation window
        with distributed.ddp_sync_grad(model_ddp, grad_accum_iter == self.config.trainer.grad_accum_iter - 1):
            with self.training_timer("forward"):
                output_batch, loss = model_ddp.training_step(data, iteration)
            self.callbacks.on_before_backward(model_ddp, loss, iteration=iteration)
            with self.training_timer("backward"):
                loss_scaled = grad_scaler.scale(loss / self.config.trainer.grad_accum_iter)
                loss_scaled.backward()
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_after_backward()
                else:
                    model_ddp.on_after_backward()
            self.callbacks.on_after_backward(model_ddp, iteration=iteration)
        grad_accum_iter += 1
        if grad_accum_iter == self.config.trainer.grad_accum_iter:
            with self.training_timer("optimizer_step"):
                self.callbacks.on_before_optimizer_step(
                    model_ddp, optimizer, scheduler, grad_scaler, iteration=iteration
                )
                grad_scaler.step(optimizer)
                grad_scaler.update()
                scheduler.step()
                self.callbacks.on_before_zero_grad(model_ddp, optimizer, scheduler, iteration=iteration)
                if self.config.trainer.distributed_parallelism == "ddp":
                    model_ddp.module.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                else:
                    model_ddp.on_before_zero_grad(optimizer, scheduler, iteration=iteration)
                optimizer.zero_grad(set_to_none=True)
            grad_accum_iter = 0
        return output_batch, loss, grad_accum_iter

    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the full validation dataset.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
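        # Note: the model is not switched back to train() mode here; the training loop re-enables it
        # before the next training step.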
        # Evaluate on the full validation set.
        with ema.ema_scope(model, enabled=model.config.ema.enabled):
            for val_iter, data_batch in enumerate(dataloader_val):
                if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                    break
                data_batch = misc.to(data_batch, device="cuda")
                self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
                output_batch, loss = model.validation_step(data_batch, iteration)
                self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)