Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import torch | |
from mmengine.model import BaseModule | |
from mmpretrain.registry import MODELS | |
class CrossCorrelationLoss(BaseModule):
    """Cross correlation loss function.

    Compute the on-diagonal and off-diagonal loss of a square
    cross-correlation matrix ``C``::

        loss = sum_i (C_ii - 1) ** 2 + lambd * sum_{i != j} C_ij ** 2

    Args:
        lambd (float): The weight for the off-diagonal loss.
            Defaults to 0.0051.
    """

    def __init__(self, lambd: float = 0.0051) -> None:
        super().__init__()
        self.lambd = lambd

    def forward(self, cross_correlation_matrix: torch.Tensor) -> torch.Tensor:
        """Forward function of cross correlation loss.

        The input tensor is left unmodified: all arithmetic is done
        out-of-place, so the caller can keep using the matrix afterwards and
        a leaf tensor with ``requires_grad=True`` is accepted.

        Args:
            cross_correlation_matrix (torch.Tensor): The square cross
                correlation matrix.

        Returns:
            torch.Tensor: Scalar cross correlation loss.

        Raises:
            ValueError: If ``cross_correlation_matrix`` is 2-D but not
                square.
        """
        # Pull each diagonal entry towards 1. ``torch.diagonal`` returns a
        # view of the input, so an in-place op here would overwrite the
        # caller's matrix; use out-of-place ops instead.
        on_diag = (torch.diagonal(cross_correlation_matrix) - 1).pow(2).sum()
        # Push every off-diagonal entry towards 0. The helper's result may
        # alias the input's storage, so avoid ``pow_`` here as well.
        off_diag = self.off_diagonal(cross_correlation_matrix).pow(2).sum()
        return on_diag + self.lambd * off_diag

    def off_diagonal(self, x: torch.Tensor) -> torch.Tensor:
        """Return the off-diagonal elements of a square matrix, flattened.

        Elements are returned in row-major order. The result may share
        storage with ``x`` (for example when ``n == 2``), so callers should
        not modify it in place.

        Args:
            x (torch.Tensor): A square matrix of shape ``(n, n)``.

        Returns:
            torch.Tensor: 1-D tensor holding the ``n * (n - 1)``
            off-diagonal elements of ``x``.

        Raises:
            ValueError: If ``x`` is not a square 2-D tensor.
        """
        if x.dim() != 2 or x.shape[0] != x.shape[1]:
            raise ValueError('off_diagonal expects a square 2-D tensor, '
                             f'got shape {tuple(x.shape)}')
        n = x.shape[0]
        # In the flattened matrix the diagonal sits at indices k * (n + 1).
        # Dropping the last element (the final diagonal entry) leaves
        # (n - 1) * (n + 1) elements. Viewed as (n - 1, n + 1) rows, each row
        # starts with a diagonal entry, so slicing off column 0 removes
        # exactly the diagonal.
        return x.flatten()[:-1].view(n - 1, n + 1)[:, 1:].flatten()