TextureScraping / models /week0417 /focal_frequency_loss.py
sunshineatnoon
Add application file
1b2a9b1
import torch
import torch.nn as nn
class FocalFrequencyLoss(nn.Module):
    """The torch.nn.Module class that implements focal frequency loss - a
    frequency domain loss function for optimizing generative models.

    Ref:
    Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021.
    <https://arxiv.org/pdf/2012.12821.pdf>

    Args:
        loss_weight (float): weight for focal frequency loss. Default: 1.0
        alpha (float): the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0
        patch_factor (int): the factor to crop image patches for patch-based focal frequency loss. Default: 1
        ave_spectrum (bool): whether to use minibatch average spectrum. Default: False
        log_matrix (bool): whether to adjust the spectrum weight matrix by logarithm. Default: False
        batch_matrix (bool): whether to calculate the spectrum weight matrix using batch-based statistics. Default: False
    """

    def __init__(self, loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False,
                 batch_matrix=False):
        super(FocalFrequencyLoss, self).__init__()
        self.loss_weight = loss_weight
        self.alpha = alpha
        self.patch_factor = patch_factor
        self.ave_spectrum = ave_spectrum
        self.log_matrix = log_matrix
        self.batch_matrix = batch_matrix

    @staticmethod
    def _dft2(y):
        """Orthonormal full 2D DFT over the last two dims of a real tensor.

        Returns a real tensor with one extra trailing axis of size 2 holding
        (real, imaginary) parts, i.e. the layout of the legacy
        ``torch.rfft(y, 2, onesided=False, normalized=True)``.
        """
        fft_module = getattr(torch, 'fft', None)
        if fft_module is not None and hasattr(fft_module, 'fft2'):
            # PyTorch >= 1.8 removed torch.rfft; torch.fft.fft2 returns a complex
            # tensor, so split it back into a trailing (real, imag) axis.
            freq = fft_module.fft2(y, norm='ortho')
            return torch.stack([freq.real, freq.imag], -1)
        # Legacy PyTorch without the torch.fft module.
        return torch.rfft(y, 2, onesided=False, normalized=True)

    def tensor2freq(self, x):
        """Crop ``x`` into patch_factor x patch_factor patches and return their 2D spectra.

        Args:
            x (torch.Tensor): tensor of shape (N, C, H, W); H and W must be
                divisible by ``self.patch_factor``.

        Returns:
            torch.Tensor: shape (N, patch_factor**2, C, H // patch_factor,
            W // patch_factor, 2); the last axis holds real and imaginary parts.
        """
        patch_factor = self.patch_factor
        _, _, h, w = x.shape
        assert h % patch_factor == 0 and w % patch_factor == 0, (
            'Image height and width should be divisible by patch factor')
        patch_h = h // patch_factor
        patch_w = w // patch_factor
        # Patches in row-major order, stacked on a new dim 1.
        patch_list = [
            x[:, :, i * patch_h:(i + 1) * patch_h, j * patch_w:(j + 1) * patch_w]
            for i in range(patch_factor)
            for j in range(patch_factor)
        ]
        y = torch.stack(patch_list, 1)
        # perform 2D DFT (real-to-complex, orthonormalization)
        return self._dft2(y)

    def _spectrum_weight(self, distance):
        """Build the dynamic spectrum weight matrix in [0, 1].

        Args:
            distance (torch.Tensor): detached squared spectral distance of shape
                (N, P, C, H, W) as produced from ``tensor2freq`` outputs.

        Returns:
            torch.Tensor: weights of the same shape, normalized by the per-(N, P, C)
            maximum (or the global maximum when ``batch_matrix`` is set).
        """
        # Euclidean distance raised to alpha (continuous, dynamic weighting).
        weight = torch.sqrt(distance) ** self.alpha
        # whether to adjust the spectrum weight matrix by logarithm
        if self.log_matrix:
            weight = torch.log(weight + 1.0)
        # whether to calculate the spectrum weight matrix using batch-based statistics
        if self.batch_matrix:
            weight = weight / weight.max()
        else:
            weight = weight / weight.max(-1).values.max(-1).values[:, :, :, None, None]
        # A spectrum that matches exactly gives 0 / 0 = NaN; give it zero weight.
        weight[torch.isnan(weight)] = 0.0
        return torch.clamp(weight, min=0.0, max=1.0)

    def loss_formulation(self, recon_freq, real_freq, matrix=None):
        """Weighted mean squared spectral distance.

        Args:
            recon_freq (torch.Tensor): spectrum of the prediction (from ``tensor2freq``).
            real_freq (torch.Tensor): spectrum of the target (from ``tensor2freq``).
            matrix (torch.Tensor, optional): predefined weight matrix; it is detached
                and must lie in [0, 1]. If None, weights are computed online.

        Returns:
            torch.Tensor: scalar loss.
        """
        # frequency distance using (squared) Euclidean distance
        tmp = (recon_freq - real_freq) ** 2
        freq_distance = tmp[..., 0] + tmp[..., 1]

        # spectrum weight matrix
        if matrix is not None:
            # if the matrix is predefined
            weight_matrix = matrix.detach()
        else:
            # computed from detached values: the weights are constants for
            # backprop, so no graph needs to be built for them
            weight_matrix = self._spectrum_weight(freq_distance.detach())

        assert weight_matrix.min().item() >= 0 and weight_matrix.max().item() <= 1, (
            'The values of spectrum weight matrix should be in the range [0, 1], '
            'but got Min: %.10f Max: %.10f' % (weight_matrix.min().item(), weight_matrix.max().item()))

        # dynamic spectrum weighting (Hadamard product)
        loss = weight_matrix * freq_distance
        return torch.mean(loss)

    def forward(self, pred, target, matrix=None, **kwargs):
        """Forward function to calculate focal frequency loss.

        Args:
            pred (torch.Tensor): of shape (N, C, H, W). Predicted tensor.
            target (torch.Tensor): of shape (N, C, H, W). Target tensor.
            matrix (torch.Tensor, optional): Element-wise spectrum weight matrix.
                Default: None (If set to None: calculated online, dynamic).

        Returns:
            torch.Tensor: scalar loss multiplied by ``loss_weight``.
        """
        pred_freq = self.tensor2freq(pred)
        target_freq = self.tensor2freq(target)

        # whether to use minibatch average spectrum
        if self.ave_spectrum:
            pred_freq = torch.mean(pred_freq, 0, keepdim=True)
            target_freq = torch.mean(target_freq, 0, keepdim=True)

        # calculate focal frequency loss
        return self.loss_formulation(pred_freq, target_freq, matrix) * self.loss_weight