Spaces:
Build error
Build error
| # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
| # SPDX-License-Identifier: Apache-2.0 | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """The evaluation metrics (PSNR, SSIM, codebook usage) for tokenizers training.""" | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| from skimage.metrics import structural_similarity as ssim | |
| from cosmos_predict1.tokenizer.modules.utils import time2batch | |
| from cosmos_predict1.utils.lazy_config import instantiate | |
# Names that TokenizerMetric looks up as attributes on its config, in this order.
_VALID_METRIC_NAMES = [
    "PSNR",
    "SSIM",
    "CodeUsage",
]

# Largest value representable by uint8, as a float; the 8-bit intensity range.
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max)
# float32 machine epsilon; keeps divisions and logarithms away from zero.
_FLOAT32_EPS = torch.finfo(torch.float32).eps

# Keys read from the tokenizer's output batch by the metric modules.
_RECONSTRUCTION = "reconstructions"
_QUANT_INFO = "quant_info"
class TokenizerMetric(nn.Module):
    """Run every configured tokenizer metric and merge the results.

    One sub-module is registered per name in ``_VALID_METRIC_NAMES``. A name
    that exists as an attribute on ``config`` is built with ``instantiate``;
    a name that is absent is replaced by a ``NULLMetric``.
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.metric_modules = nn.ModuleDict()
        for name in _VALID_METRIC_NAMES:
            if hasattr(config, name):
                metric_module = instantiate(getattr(config, name))
            else:
                metric_module = NULLMetric()
            self.metric_modules[name] = metric_module

    def forward(
        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Evaluate all registered metrics on one batch.

        Args:
            inputs: Passed through unchanged to every metric module.
                NOTE(review): annotated as a dict, but the PSNR/SSIM modules
                in this file use it as a tensor -- confirm and fix the hint.
            output_batch: Passed through unchanged to every metric module.
            iteration: Passed through unchanged to every metric module.

        Returns:
            A dict with the single key ``"metric"`` holding the union of the
            dicts returned by the individual metric modules.
        """
        merged = {}
        for metric_module in self.metric_modules.values():
            merged.update(metric_module(inputs, output_batch, iteration))
        return {"metric": merged}
class NULLMetric(torch.nn.Module):
    """Placeholder metric that reports nothing."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Ignore all arguments and return an empty dict."""
        return {}
class PSNRMetric(torch.nn.Module):
    """Peak signal-to-noise ratio between inputs and reconstructions.

    Both tensors are mapped from a nominal [-1, 1] range to 8-bit intensities
    before the error is measured, so the value is PSNR on quantized images.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _to_uint8(images: torch.Tensor) -> torch.Tensor:
        """Map values from a nominal [-1, 1] range to rounded uint8 [0, 255]."""
        unit = (images.to(torch.float32) + 1) / 2
        # Clamp before the cast: converting an out-of-range float to uint8 is
        # undefined and would wrap or overflow for values outside [-1, 1].
        unit = unit.clamp(0.0, 1.0)
        # Adding 0.5 turns the truncating cast into round-to-nearest.
        return (unit * _UINT8_MAX_F + 0.5).to(torch.uint8)

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Compute the batch-averaged PSNR.

        Args:
            inputs: Ground-truth tensor, 4-D or 5-D, nominally in [-1, 1].
            output_batch: Must contain the reconstruction tensor under the
                ``_RECONSTRUCTION`` key, with the same layout as ``inputs``.
            iteration: Unused.

        Returns:
            ``{"PSNR": scalar tensor}``, the mean PSNR over the batch.
        """
        reconstructions = output_batch[_RECONSTRUCTION]
        if inputs.ndim == 5:
            # 5-D input is reduced to 4-D by time2batch; presumably this folds
            # the temporal axis into the batch axis -- helper not visible here.
            inputs, _ = time2batch(inputs)
            reconstructions, _ = time2batch(reconstructions)

        # Quantize to uint8, then return to float32 for the arithmetic.
        true_image = self._to_uint8(inputs).to(torch.float32)
        pred_image = self._to_uint8(reconstructions).to(torch.float32)

        # PSNR from the per-sample mean squared error over all non-batch axes.
        mse = torch.mean((true_image - pred_image) ** 2, dim=(1, 2, 3))
        # The epsilon keeps identical images from dividing by zero.
        psnr = 10 * torch.log10(_UINT8_MAX_F**2 / (mse + _FLOAT32_EPS))
        return dict(PSNR=torch.mean(psnr))
class SSIMMetric(torch.nn.Module):
    """Structural similarity (SSIM) between inputs and reconstructions.

    Both tensors are mapped from a nominal [-1, 1] range to 8-bit intensities,
    moved to the CPU, and scored one image at a time with scikit-image.
    """

    def __init__(self) -> None:
        super().__init__()

    @staticmethod
    def _to_uint8(images: torch.Tensor) -> torch.Tensor:
        """Map values from a nominal [-1, 1] range to rounded uint8 [0, 255]."""
        unit = (images.to(torch.float32) + 1) / 2
        # Clamp before the cast: converting an out-of-range float to uint8 is
        # undefined and would wrap or overflow for values outside [-1, 1].
        unit = unit.clamp(0.0, 1.0)
        # Adding 0.5 turns the truncating cast into round-to-nearest.
        return (unit * _UINT8_MAX_F + 0.5).to(torch.uint8)

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Compute the batch-averaged SSIM.

        Args:
            inputs: Ground-truth tensor, 4-D or 5-D, nominally in [-1, 1].
                After the optional 5-D reduction, axis 1 is treated as the
                channel axis.
            output_batch: Must contain the reconstruction tensor under the
                ``_RECONSTRUCTION`` key, with the same layout as ``inputs``.
            iteration: Unused.

        Returns:
            ``{"SSIM": scalar float32 tensor}`` on the device of ``inputs``.
        """
        reconstructions = output_batch[_RECONSTRUCTION]
        if inputs.ndim == 5:
            # 5-D input is reduced to 4-D by time2batch; presumably this folds
            # the temporal axis into the batch axis -- helper not visible here.
            inputs, _ = time2batch(inputs)
            reconstructions, _ = time2batch(reconstructions)

        # Channels-last numpy arrays on the CPU, as scikit-image expects here.
        true_image_np = self._to_uint8(inputs).permute(0, 2, 3, 1).cpu().numpy()
        pred_image_np = self._to_uint8(reconstructions).permute(0, 2, 3, 1).cpu().numpy()

        # NOTE(review): both `multichannel` and `channel_axis` are passed,
        # presumably to work across scikit-image versions -- confirm against
        # the pinned version before removing either.
        ssim_values = [
            ssim(true_image_i, pred_image_i, data_range=_UINT8_MAX_F, multichannel=True, channel_axis=-1)
            for true_image_i, pred_image_i in zip(true_image_np, pred_image_np)
        ]
        ssim_mean = np.mean(ssim_values)
        return dict(SSIM=torch.tensor(ssim_mean, dtype=torch.float32, device=inputs.device))
class CodeUsageMetric(torch.nn.Module):
    """Perplexity of codebook usage (only meaningful for discrete tokenizers).

    The code indices of a batch are histogrammed over the codebook and the
    exponential of the Shannon entropy (natural log) of that histogram is
    reported.

    Args:
        codebook_size: The total number of codebook entries; used as the
            minimum histogram length so unused codes count as zero.
    """

    def __init__(self, codebook_size: int) -> None:
        super().__init__()
        self.codebook_size = codebook_size

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Compute the codebook-usage perplexity for one batch.

        Args:
            inputs: Unused.
            output_batch: Must contain the code-index tensor under the
                ``_QUANT_INFO`` key (non-negative integer values, as
                ``torch.bincount`` requires).
            iteration: Unused.

        Returns:
            ``{"CodeUsage": scalar float32 tensor}`` on the device of the
            code indices.
        """
        code_indices = output_batch[_QUANT_INFO]
        usage_counts = torch.bincount(code_indices.flatten().int(), minlength=self.codebook_size)
        total_usage = usage_counts.sum().float()
        usage_probs = usage_counts.float() / total_usage
        # The epsilon keeps log() finite for codes that were never used.
        entropy = -torch.sum(usage_probs * torch.log(usage_probs + _FLOAT32_EPS))
        perplexity = torch.exp(entropy)
        # perplexity is already a tensor: detach and convert it rather than
        # re-wrapping it in torch.tensor(), which warns on every call.
        return dict(CodeUsage=perplexity.detach().to(device=code_indices.device, dtype=torch.float32))