# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import os
import threading

import torch
from torch._dynamo.eval_frame import OptimizedModule as torch_OptimizedModule

from cosmos_predict1.utils import callback, distributed, ema, log, misc
from cosmos_predict1.utils.checkpointer import Checkpointer
from cosmos_predict1.utils.config import CheckpointConfig, JobConfig
from cosmos_predict1.utils.model import Model


class TokenizerCheckpointer(Checkpointer):
    """The tokenizer checkpointer; extends the shared checkpointer.

    Supports checkpoint saving/loading to local disk:
    - network weights and training optimizer states.
    - optionally, a TorchScript export of the EMA model.
    """

    def __init__(self, config_checkpoint: CheckpointConfig, config_job: JobConfig, callbacks: callback.CallBackGroup):
        super().__init__(config_checkpoint, config_job, callbacks)
        self.callbacks = callbacks
        self.config_jit = config_checkpoint.jit
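        # Settings for the optional TorchScript export (enabled flag, example input shape, dtype,
        # device, and tracing strictness), consumed by _get_ema_jit() below.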

    def save(
        self,
        model: Model,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler.LRScheduler,
        grad_scaler: torch.amp.GradScaler,
        iteration: int = -1,
        **ignore_kwargs,
    ) -> None:
        """Saves network weights, optimizer parameters, and scheduler parameters to a checkpoint.

        Args:
            model (Model): The PyTorch model.
            optimizer: The model optimizer.
            scheduler: The optimization scheduler.
            grad_scaler: The gradient scaler (for mixed precision training).
            iteration: Current iteration number.
        """
        self.callbacks.on_save_checkpoint_start(model, iteration)
        model.eval()
        checkpoint_file = f"iter_{iteration:09}.pt"
        if distributed.get_rank() == 0:
            state_dict = dict(
                model=model.state_dict(),
                optimizer=optimizer.state_dict(),
                scheduler=scheduler.state_dict(),
                grad_scaler=grad_scaler.state_dict(),
                iteration=iteration,
            )
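            # Copy the full training state to CPU before handing it to the background saver thread.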
            state_dict = misc.to(state_dict, device="cpu")
            self.callbacks.on_save_checkpoint(model, state_dict=state_dict)
            # Wait for the previous saver thread to finish.
            if self.save_thread:
                self.save_thread.join()
            # Run the checkpoint saver in a separate thread.
            self.save_thread = threading.Thread(
                target=self._save_worker_local,
                daemon=False,
                args=(state_dict, self._get_ema_jit(model), checkpoint_file, distributed.get_rank()),
            )
            self.save_thread.start()
        # Note: checkpoints are saved on a separate thread, so this callback fires before the save
        # completes; check the logs from on_save_checkpoint_success() for confirmation.
        self.callbacks.on_save_checkpoint_end(model=None, iteration=iteration)

    def _save_worker_local(
        self,
        state_dict: dict[str, torch.Tensor],
        jit_models: dict[str, torch.ScriptModule],
        checkpoint_file: str,
        rank: int = 0,
    ) -> None:
        """Worker that saves the checkpoint to local disk, spawned as a child thread (runs in parallel with training).

        Args:
            state_dict: The state dict of the model/optimizer/scheduler.
            jit_models: A dict of TorchScript EMA models: the encoder, the decoder, and the full model.
            checkpoint_file (str): The file name of the model checkpoint.
            rank (int): GPU device (default: 0).
        """
        checkpoint_path = os.path.join(self.checkpoint_dir_local, checkpoint_file)
        os.makedirs(self.checkpoint_dir_local, exist_ok=True)
        try:
            torch.save(state_dict, checkpoint_path)
            for key, jit_model in jit_models.items():
                checkpoint_jit = checkpoint_path.replace(".pt", f"_{key}.jit")
                torch.jit.save(jit_model, checkpoint_jit)
                log.success(f"Saved checkpoint: {checkpoint_jit}")
            if rank == 0:
                self._write_latest_checkpoint_file(checkpoint_file)
            log.success(f"Saved checkpoint (local): {checkpoint_path}")
            iteration = int(checkpoint_file.replace("iter_", "").replace(".pt", ""))
            self.callbacks.on_save_checkpoint_success(iteration=iteration)
        except Exception as e:  # noqa: BLE001
            log.exception(f"Checkpoint failed to save (local): {e}")

    def _get_ema_jit(self, model: Model) -> dict[str, torch.ScriptModule]:
        """Returns TorchScript versions of the EMA model compiled by PyTorch JIT."""
        if not self.config_jit.enabled:
            return dict()
        input_shape = tuple(self.config_jit.input_shape)
        example_input = torch.randn(input_shape)
        dtype = getattr(torch, self.config_jit.dtype)
        example_input = example_input.to(self.config_jit.device).to(dtype)
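        # The example input above is only used to drive torch.jit.trace; its shape, dtype and
        # device come from the checkpoint JIT config.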
        with ema.ema_scope(model, enabled=model.config.ema.enabled):
            _model = model.network
            if isinstance(_model, torch_OptimizedModule):
                _model = _model._orig_mod
            # Make sure the jit model outputs consistently across consecutive calls.
            # See: https://github.com/pytorch/pytorch/issues/74534
            torch._C._jit_set_texpr_fuser_enabled(False)
            ema_jit = torch.jit.trace(_model, example_input, strict=self.config_jit.strict)
            encoder_jit = torch.jit.trace(_model.encoder_jit(), example_input, strict=self.config_jit.strict)
            decoder_example = encoder_jit(example_input)
            if isinstance(decoder_example, tuple):
                decoder_example = decoder_example[0]
            else:
                assert isinstance(decoder_example, torch.Tensor), "decoder_example should be a tensor or tuple"
            decoder_jit = torch.jit.trace(_model.decoder_jit(), decoder_example, strict=self.config_jit.strict)
        return {"ema": ema_jit, "enc": encoder_jit, "dec": decoder_jit}