roll-ai's picture
Upload 381 files
b6af722 verified
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""The combined loss functions for continuous-space tokenizers training."""
import numpy as np
import torch
import torch.nn as nn
from skimage.metrics import structural_similarity as ssim
from cosmos_predict1.tokenizer.modules.utils import time2batch
from cosmos_predict1.utils.lazy_config import instantiate
# Metric names that TokenizerMetric looks up on its config; any name missing
# from the config is replaced by a NULLMetric.
_VALID_METRIC_NAMES = ["PSNR", "SSIM", "CodeUsage"]
# 255.0 as a Python float: the top of the 8-bit pixel range used by PSNR/SSIM.
_UINT8_MAX_F = float(np.iinfo(np.uint8).max)
# float32 machine epsilon, added to denominators / log arguments so that
# a zero MSE or a zero probability does not produce inf or -inf.
_FLOAT32_EPS = float(np.finfo(np.float32).eps)
# Keys read from the model's output batch.
_RECONSTRUCTION = "reconstructions"
_QUANT_INFO = "quant_info"
class TokenizerMetric(nn.Module):
    """Aggregate the configured tokenizer metrics into a single dict.

    For every name in ``_VALID_METRIC_NAMES``, the config is checked for an
    attribute of that name. If present, it is passed to ``instantiate`` to build
    the metric module; otherwise a ``NULLMetric`` (which reports nothing) is used.

    Args:
        config: Object with optional ``PSNR`` / ``SSIM`` / ``CodeUsage``
            attributes (presumably lazy-config specs understood by
            ``instantiate`` -- TODO confirm).
    """

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.metric_modules = nn.ModuleDict()
        for key in _VALID_METRIC_NAMES:
            self.metric_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NULLMetric()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, dict[str, torch.Tensor]]:
        """Run every metric module and merge their results.

        Args:
            inputs: Ground-truth tensor, forwarded unchanged to each metric.
            output_batch: Model outputs, forwarded unchanged to each metric.
            iteration: Training iteration, forwarded unchanged to each metric.

        Returns:
            ``{"metric": merged}``, where ``merged`` is the union of every
            module's result dict; later modules overwrite duplicate keys.
        """
        metric: dict[str, torch.Tensor] = {}
        # The ModuleDict keys only serve registration, so iterate the modules directly.
        for module in self.metric_modules.values():
            metric.update(module(inputs, output_batch, iteration))
        return dict(metric=metric)
class NULLMetric(torch.nn.Module):
    """Placeholder metric that contributes no entries.

    TokenizerMetric substitutes it for any metric absent from the config, so
    merging its (empty) result leaves the combined metric dict unchanged.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Ignore every argument and return a fresh empty dict."""
        empty: dict[str, torch.Tensor] = {}
        return empty
class PSNRMetric(torch.nn.Module):
    """Peak signal-to-noise ratio between inputs and reconstructions.

    Both tensors are mapped from [-1, 1] onto rounded 8-bit levels [0, 255]
    before the MSE is taken, so the score matches PSNR on uint8 images.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Return ``{"PSNR": mean PSNR over the batch}`` in dB.

        Args:
            inputs: Ground-truth images in [-1, 1]. A 5-D tensor is first passed
                through ``time2batch`` (presumably folding time into the batch
                axis -- TODO confirm); the MSE then averages over dims 1-3.
            output_batch: Must hold reconstructions shaped like ``inputs``
                under ``_RECONSTRUCTION``.
            iteration: Unused; kept for the shared metric interface.
        """
        reconstructions = output_batch[_RECONSTRUCTION]
        if inputs.ndim == 5:
            inputs, _ = time2batch(inputs)
            reconstructions, _ = time2batch(reconstructions)

        true_image = self._quantize(inputs)
        pred_image = self._quantize(reconstructions)

        # Per-sample MSE, then PSNR = 10 * log10(MAX^2 / MSE); eps avoids log(inf).
        mse = torch.mean((true_image - pred_image) ** 2, dim=(1, 2, 3))
        psnr = 10 * torch.log10(_UINT8_MAX_F**2 / (mse + _FLOAT32_EPS))
        return dict(PSNR=torch.mean(psnr))

    @staticmethod
    def _quantize(x: torch.Tensor) -> torch.Tensor:
        """Map [-1, 1] values to rounded 8-bit levels, returned as float32."""
        # Clamp before the uint8 cast: converting an out-of-range float to uint8
        # is undefined and in practice wraps (1.05 would land on 5, not 255).
        x = ((x.to(torch.float32) + 1) / 2).clamp(0.0, 1.0)
        return (x * _UINT8_MAX_F + 0.5).to(torch.uint8).to(torch.float32)
class SSIMMetric(torch.nn.Module):
    """Mean structural similarity (SSIM) between inputs and reconstructions.

    Images are mapped from [-1, 1] onto rounded uint8 levels and scored one at
    a time with scikit-image on the CPU.
    """

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Return ``{"SSIM": mean SSIM over the batch}`` as a float32 scalar on ``inputs.device``.

        Args:
            inputs: Ground-truth images in [-1, 1]. A 5-D tensor is first passed
                through ``time2batch`` (presumably folding time into the batch
                axis -- TODO confirm). After that the tensor must be 4-D with
                channels in dim 1, since dim 1 is moved last for skimage.
            output_batch: Must hold reconstructions shaped like ``inputs``
                under ``_RECONSTRUCTION``.
            iteration: Unused; kept for the shared metric interface.
        """
        reconstructions = output_batch[_RECONSTRUCTION]
        if inputs.ndim == 5:
            inputs, _ = time2batch(inputs)
            reconstructions, _ = time2batch(reconstructions)

        # Channels-last numpy arrays on the CPU, as skimage expects.
        true_image_np = self._quantize(inputs).permute(0, 2, 3, 1).cpu().numpy()
        pred_image_np = self._quantize(reconstructions).permute(0, 2, 3, 1).cpu().numpy()

        # `channel_axis` supersedes the deprecated `multichannel` flag, so only it is passed.
        ssim_values = [
            ssim(true_i, pred_i, data_range=_UINT8_MAX_F, channel_axis=-1)
            for true_i, pred_i in zip(true_image_np, pred_image_np)
        ]
        ssim_mean = np.mean(ssim_values)
        return dict(SSIM=torch.tensor(ssim_mean, dtype=torch.float32, device=inputs.device))

    @staticmethod
    def _quantize(x: torch.Tensor) -> torch.Tensor:
        """Map [-1, 1] values to rounded uint8 levels in [0, 255]."""
        # Clamp before the uint8 cast: converting an out-of-range float to uint8
        # is undefined and in practice wraps (1.05 would land on 5, not 255).
        x = ((x.to(torch.float32) + 1) / 2).clamp(0.0, 1.0)
        return (x * _UINT8_MAX_F + 0.5).to(torch.uint8)
class CodeUsageMetric(torch.nn.Module):
    """Perplexity of codebook usage (only meaningful for discrete tokenizers).

    The perplexity ``exp(H)`` of the empirical code histogram is about 1 when a
    single code is used. It reaches ``codebook_size`` when every code is used
    equally often, assuming all indices are ``< codebook_size``.

    Args:
        codebook_size: Total number of codebook entries; the histogram always
            has at least this many bins.
    """

    def __init__(self, codebook_size: int) -> None:
        super().__init__()
        self.codebook_size = codebook_size

    def forward(
        self, inputs: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        """Return ``{"CodeUsage": perplexity}`` as a float32 scalar tensor.

        Args:
            inputs: Unused.
            output_batch: Must hold the code indices under ``_QUANT_INFO``.
                They are flattened and cast to int, and must be non-negative,
                as ``torch.bincount`` requires.
            iteration: Unused.

        Note:
            An empty index tensor with ``codebook_size > 0`` gives a total
            count of 0, so the probabilities (and the result) are NaN.
        """
        code_indices = output_batch[_QUANT_INFO]
        usage_counts = torch.bincount(code_indices.flatten().int(), minlength=self.codebook_size)
        total_usage = usage_counts.sum().float()
        usage_probs = usage_counts.float() / total_usage
        # eps keeps log() finite for unused codes; their 0 * log(eps) term vanishes.
        entropy = -torch.sum(usage_probs * torch.log(usage_probs + _FLOAT32_EPS))
        perplexity = torch.exp(entropy)
        # perplexity is already a tensor on code_indices' device. Re-wrapping it
        # with torch.tensor() would copy it and emit a UserWarning.
        return dict(CodeUsage=perplexity.to(torch.float32))