"""The combined loss functions for continuous-space tokenizers training.""" |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from skimage.metrics import structural_similarity as ssim |
|
|
|
|
|
from cosmos_predict1.tokenizer.modules.utils import time2batch |
|
|
from cosmos_predict1.utils.lazy_config import instantiate |
|
|
|
|
|
_VALID_METRIC_NAMES = ["PSNR", "SSIM", "CodeUsage"] |
|
|
_UINT8_MAX_F = float(torch.iinfo(torch.uint8).max) |
|
|
_FLOAT32_EPS = torch.finfo(torch.float32).eps |
|
|
_RECONSTRUCTION = "reconstructions" |
|
|
_QUANT_INFO = "quant_info" |
|
|
|
|
|
|
|
|
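
# The metric modules below read reconstructions from output_batch[_RECONSTRUCTION]
# and, for discrete tokenizers, code indices from output_batch[_QUANT_INFO].
# Image tensors are assumed to lie in [-1, 1]; PSNR and SSIM first map them to
# uint8 values in [0, 255].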


class TokenizerMetric(nn.Module):
    """Aggregates the configured per-metric modules for tokenizer evaluation."""

    def __init__(self, config) -> None:
        super().__init__()
        self.config = config
        self.metric_modules = nn.ModuleDict()
        for key in _VALID_METRIC_NAMES:
            # Instantiate a metric only when the config provides it; otherwise use a no-op.
            self.metric_modules[key] = instantiate(getattr(config, key)) if hasattr(config, key) else NULLMetric()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        metric = dict()
        for module in self.metric_modules.values():
            metric.update(module(inputs, output_batch, iteration))
        return dict(metric=metric)
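
# A minimal construction sketch (the SimpleNamespace config below is hypothetical,
# not one of this repo's training configs): TokenizerMetric only requires that
# `config` expose one instantiable attribute per name in _VALID_METRIC_NAMES, and
# any missing attribute silently falls back to the no-op NULLMetric defined below.
#
#   from types import SimpleNamespace
#   metric = TokenizerMetric(SimpleNamespace())  # every metric resolves to NULLMetric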


class NULLMetric(torch.nn.Module):
    """No-op metric used when a metric is not configured."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        return dict()


class PSNRMetric(torch.nn.Module):
    """Peak signal-to-noise ratio (in dB) between inputs and reconstructions."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        reconstructions = output_batch[_RECONSTRUCTION]
        if inputs.ndim == 5:
            # Fold the temporal axis into the batch axis for 5D video inputs.
            inputs, _ = time2batch(inputs)
            reconstructions, _ = time2batch(reconstructions)

        # Map from [-1, 1] to [0, 1], then quantize to uint8 as in standard 8-bit PSNR.
        true_image = (inputs.to(torch.float32) + 1) / 2
        pred_image = (reconstructions.to(torch.float32) + 1) / 2
        true_image = (true_image * _UINT8_MAX_F + 0.5).to(torch.uint8)
        pred_image = (pred_image * _UINT8_MAX_F + 0.5).to(torch.uint8)

        # PSNR = 10 * log10(MAX^2 / MSE), computed per sample and averaged over the batch.
        true_image = true_image.to(torch.float32)
        pred_image = pred_image.to(torch.float32)
        mse = torch.mean((true_image - pred_image) ** 2, dim=(1, 2, 3))
        psnr = 10 * torch.log10(_UINT8_MAX_F**2 / (mse + _FLOAT32_EPS))
        return dict(PSNR=torch.mean(psnr))
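
# Worked sanity check for PSNRMetric above: an MSE of 1.0 on the 0-255 scale gives
# 10 * log10(255**2 / 1.0) ≈ 48.13 dB, and identical images stay finite because of
# the _FLOAT32_EPS term in the denominator.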


class SSIMMetric(torch.nn.Module):
    """Structural similarity (SSIM) between inputs and reconstructions."""

    def __init__(self) -> None:
        super().__init__()

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        reconstructions = output_batch[_RECONSTRUCTION]
        if inputs.ndim == 5:
            # Fold the temporal axis into the batch axis for 5D video inputs.
            inputs, _ = time2batch(inputs)
            reconstructions, _ = time2batch(reconstructions)

        # Map from [-1, 1] to [0, 1], then quantize to uint8.
        true_image = (inputs.to(torch.float32) + 1) / 2
        pred_image = (reconstructions.to(torch.float32) + 1) / 2
        true_image = (true_image * _UINT8_MAX_F + 0.5).to(torch.uint8)
        pred_image = (pred_image * _UINT8_MAX_F + 0.5).to(torch.uint8)

        # scikit-image expects channel-last numpy arrays.
        true_image_np = true_image.permute(0, 2, 3, 1).cpu().numpy()
        pred_image_np = pred_image.permute(0, 2, 3, 1).cpu().numpy()

        ssim_values = []
        for true_image_i, pred_image_i in zip(true_image_np, pred_image_np):
            # `channel_axis` replaces the deprecated `multichannel` argument in scikit-image.
            ssim_value = ssim(true_image_i, pred_image_i, data_range=_UINT8_MAX_F, channel_axis=-1)
            ssim_values.append(ssim_value)
        ssim_mean = np.mean(ssim_values)
        return dict(SSIM=torch.tensor(ssim_mean, dtype=torch.float32, device=inputs.device))
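
# Note: for 5D video inputs, frames are folded into the batch axis, so SSIM (like
# PSNR above) is computed per frame and averaged; these metrics do not measure
# temporal consistency.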


class CodeUsageMetric(torch.nn.Module):
    """Perplexity of codebook usage (only meaningful for discrete tokenizers).

    Code indices are read from ``output_batch[_QUANT_INFO]``. The perplexity is
    exp(H), where H is the entropy of the empirical distribution of code usage;
    it ranges from 1 (a single code used) up to ``codebook_size`` (all codes
    used uniformly).

    :param codebook_size: The total number of codebook entries.
    """

    def __init__(self, codebook_size: int) -> None:
        super().__init__()
        self.codebook_size = codebook_size

    def forward(
        self, inputs: torch.Tensor, output_batch: dict[str, torch.Tensor], iteration: int
    ) -> dict[str, torch.Tensor]:
        code_indices = output_batch[_QUANT_INFO]
        # Empirical distribution of code usage over the whole batch.
        usage_counts = torch.bincount(code_indices.flatten().int(), minlength=self.codebook_size)
        total_usage = usage_counts.sum().float()
        usage_probs = usage_counts.float() / total_usage
        entropy = -torch.sum(usage_probs * torch.log(usage_probs + _FLOAT32_EPS))
        perplexity = torch.exp(entropy)
        return dict(CodeUsage=perplexity.to(torch.float32))
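

# A minimal smoke-test sketch, not part of the training pipeline: the shapes, noise
# level, and codebook size below are arbitrary illustrative choices.
if __name__ == "__main__":
    inputs = torch.rand(2, 3, 64, 64) * 2 - 1  # NCHW images in [-1, 1]
    reconstructions = (inputs + 0.05 * torch.randn_like(inputs)).clamp(-1, 1)
    output_batch = {
        _RECONSTRUCTION: reconstructions,
        _QUANT_INFO: torch.randint(0, 512, (2, 16, 16)),
    }
    print(PSNRMetric()(inputs, output_batch, iteration=0))
    print(SSIMMetric()(inputs, output_batch, iteration=0))
    print(CodeUsageMetric(codebook_size=512)(inputs, output_batch, iteration=0))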