# Spaces: Runtime error
import torch
import torch.nn as nn
class FocalFrequencyLoss(nn.Module):
    """The torch.nn.Module class that implements focal frequency loss - a
    frequency domain loss function for optimizing generative models.

    Ref:
    Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
    <https://arxiv.org/pdf/2012.12821.pdf>

    Args:
        loss_weight (float): weight for focal frequency loss. Default: 1.0
        alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
        patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
        ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
        log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
        batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics.
            Default: False
    """

    def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False,
                 batch_matrix=False):
        super(FocalFrequencyLoss, self).__init__()
        self.loss_weight = loss_weight
        self.alpha = alpha
        self.patch_factor = patch_factor
        self.ave_spectrum = ave_spectrum
        self.log_matrix = log_matrix
        self.batch_matrix = batch_matrix

    def tensor2freq(self, x):
        """Split ``x`` into patches and return their orthonormal 2D DFT.

        Args:
            x (torch.Tensor): of shape (N, C, H, W). H and W must both be
                multiples of ``self.patch_factor``.

        Returns:
            torch.Tensor: of shape (N, patch_factor ** 2, C, H / patch_factor,
            W / patch_factor, 2); the last axis holds the real and imaginary
            parts of the full (two-sided) spectrum.

        Raises:
            AssertionError: if H or W is not a multiple of ``patch_factor``.
        """
        # crop image patches
        patch_factor = self.patch_factor
        _, _, h, w = x.shape
        assert h % patch_factor == 0 and w % patch_factor == 0, (
            'Patch factor should be divisible by image height and width')
        patch_h = h // patch_factor
        patch_w = w // patch_factor
        patch_list = [
            x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w]
            for i in range(patch_factor)
            for j in range(patch_factor)
        ]

        # stack to patch tensor
        y = torch.stack(patch_list, 1)

        # perform 2D DFT (real-to-complex, orthonormalization)
        # torch.rfft was removed in PyTorch 1.8; torch.fft.fft2 is its replacement.
        # On old releases torch.fft may be a plain function (no fft2 attribute),
        # in which case the legacy call is still available and is used instead.
        fft_module = getattr(torch, 'fft', None)
        if hasattr(fft_module, 'fft2'):
            freq = fft_module.fft2(y, norm='ortho')
            # keep the legacy layout: trailing axis of size 2 = (real, imag)
            return torch.stack([freq.real, freq.imag], -1)
        return torch.rfft(y, 2, onesided=False, normalized=True)

    def loss_formulation(self, recon_freq, real_freq, matrix=None):
        """Compute the weighted mean squared spectral distance.

        Args:
            recon_freq (torch.Tensor): spectrum of the prediction, as returned
                by ``tensor2freq`` (last axis = real/imag).
            real_freq (torch.Tensor): spectrum of the target, same shape.
            matrix (torch.Tensor, optional): predefined spectrum weight matrix.
                If None, the matrix is derived from the current spectral
                distance. Default: None

        Returns:
            torch.Tensor: scalar loss (mean over all elements).

        Raises:
            AssertionError: if the weight matrix has values outside [0, 1].
        """
        # frequency distance using (squared) Euclidean distance
        diff_sq = (recon_freq - real_freq) ** 2
        freq_distance = diff_sq[..., 0] + diff_sq[..., 1]

        # spectrum weight matrix
        if matrix is not None:
            # if the matrix is predefined
            weight_matrix = matrix.detach()
        else:
            # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
            # (detached up front: the weights never receive gradients)
            matrix_tmp = torch.sqrt(freq_distance.detach()) ** self.alpha

            # whether to adjust the spectrum weight matrix by logarithm
            if self.log_matrix:
                matrix_tmp = torch.log(matrix_tmp + 1.0)

            # whether to calculate the spectrum weight matrix using batch-based statistics
            if self.batch_matrix:
                matrix_tmp = matrix_tmp / matrix_tmp.max()
            else:
                # normalize each (sample, patch, channel) spectrum by its own maximum
                matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]

            # 0 / 0 (identical spectra) yields NaN; treat it as zero weight
            matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
            matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
            weight_matrix = matrix_tmp.clone().detach()

        assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
            'The values of spectrum weight matrix should be in the range [0, 1], '
            'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))

        # dynamic spectrum weighting (Hadamard product)
        loss = weight_matrix * freq_distance
        return torch.mean(loss)

    def forward(self, pred, target, matrix=None, **kwargs):
        """Forward function to calculate focal frequency loss.

        Args:
            pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
            target (torch.Tensor): of shape (N, C, H, W). Target tensor.
            matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
                Default: None (If set to None: calculated online, dynamic).

        Returns:
            torch.Tensor: scalar loss, already multiplied by ``loss_weight``.
        """
        pred_freq = self.tensor2freq(pred)
        target_freq = self.tensor2freq(target)

        # whether to use minibatch average spectrum
        if self.ave_spectrum:
            pred_freq = torch.mean(pred_freq, 0, keepdim=True)
            target_freq = torch.mean(target_freq, 0, keepdim=True)

        # calculate focal frequency loss
        return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight