Spaces:
Build error
Build error
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch | |
import torch.utils.data | |
from cosmos_predict1.tokenizer.training.checkpointer import TokenizerCheckpointer | |
from cosmos_predict1.utils import ema, misc | |
from cosmos_predict1.utils.model import Model | |
from cosmos_predict1.utils.trainer import Trainer | |
class TokenizerTrainer(Trainer):
    """The tokenizer trainer, extended from Trainer.

    It extends model training functionality with a tokenizer-specific
    checkpointer and a validation loop that runs each batch twice: once
    directly and once inside ``ema.ema_scope``.

    Attributes:
        model_config: The ``config.model.config`` object captured at construction.
        checkpointer (TokenizerCheckpointer): checkpointer object to save/load model weights and optimizer states.
        training_timer (misc.Timer): Timer object to time code blocks and functions.
            NOTE(review): not assigned in this class; presumably provided by the
            base ``Trainer`` -- confirm.
    """

    def __init__(self, config):
        """Build the trainer and swap in the tokenizer checkpointer.

        Args:
            config: Job configuration. This constructor reads ``config.model.config``,
                ``config.checkpoint`` and ``config.job``; everything else is
                forwarded to ``Trainer.__init__``.
        """
        super(TokenizerTrainer, self).__init__(config)
        self.model_config = config.model.config
        # ``self.callbacks`` is expected to exist once the base constructor has run.
        self.checkpointer = TokenizerCheckpointer(config.checkpoint, config.job, callbacks=self.callbacks)

    @torch.no_grad()
    def validate(self, model: Model, dataloader_val: torch.utils.data.DataLoader, iteration: int = 0) -> None:
        """Validate on the validation dataset, with and without the EMA scope.

        Iterates over ``dataloader_val`` and stops early once
        ``self.config.trainer.max_val_iter`` batches have been processed (when
        that setting is not ``None``). The whole method runs under
        ``torch.no_grad()`` so no autograd graph is built for the forward
        passes; ``model.eval()`` alone does not disable gradient tracking.

        Args:
            model (Model): The PyTorch model.
            dataloader_val (torch.utils.data.DataLoader): The validation data loader.
            iteration (int): Current iteration number.
        """
        self.callbacks.on_validation_start(model, dataloader_val, iteration=iteration)
        model.eval()
        # Evaluate on the validation set, optionally capped at max_val_iter batches.
        for val_iter, data_batch in enumerate(dataloader_val):
            if self.config.trainer.max_val_iter is not None and val_iter >= self.config.trainer.max_val_iter:
                break
            data_batch = misc.to(data_batch, device="cuda")
            self.callbacks.on_validation_step_start(model, data_batch, iteration=iteration)
            # First pass: only the outputs are kept, the second return value is discarded.
            output_batch, _ = model.validation_step(data_batch, iteration)
            # Second pass inside the EMA scope (presumably swaps in EMA weights when
            # enabled -- confirm against ema.ema_scope). Its loss is what gets reported.
            with ema.ema_scope(model, enabled=model.config.ema.enabled):
                ema_output_batch, loss = model.validation_step(data_batch, iteration, ema_model=True)
            # Merge both passes; assumes dict-like outputs, so EMA entries win on key clashes.
            output_batch.update(ema_output_batch)
            self.callbacks.on_validation_step_end(model, data_batch, output_batch, loss, iteration=iteration)
        self.callbacks.on_validation_end(model, iteration=iteration)