# Copyright (C) 2024-present Naver Corporation. All rights reserved.
# Licensed under CC BY-NC-SA 4.0 (non-commercial use only).
#
# --------------------------------------------------------
# losses for sparse ga
# --------------------------------------------------------
import torch
import numpy as np


def l05_loss(x, y):
    """Return the square root of the Euclidean distance between x and y.

    The distance is taken over the last dimension, so the result has one
    dimension fewer than the (broadcast) inputs.
    """
    return torch.linalg.norm(x - y, dim=-1).sqrt()


def l1_loss(x, y):
    """Return the Euclidean (L2-norm) distance between x and y over the last dim.

    Despite the name, this is the per-vector L2 norm of the difference, i.e. an
    "L1 on distances" rather than an elementwise L1 norm.
    """
    return torch.linalg.norm(x - y, dim=-1)


def gamma_loss(gamma, mul=1, offset=None, clip=np.inf):
    """Build a loss ``(mul * clip(d) + offset) ** gamma - offset ** gamma``.

    ``d`` is the distance returned by :func:`l1_loss`, clipped from above at
    ``clip`` before being scaled by ``mul``. The loss is always zero at
    ``d == 0``.

    Args:
        gamma: exponent applied to the shifted distance. The automatic offset
            formula requires ``gamma != 0``. For non-integer ``gamma`` it also
            requires ``gamma > 0`` to stay real-valued.
        mul: scale applied to the clipped distance.
        offset: shift added before exponentiation. If None, it is chosen so
            that the loss has slope ``mul`` at ``d == 0`` (see below).
        clip: upper bound applied to the distance before scaling.

    Returns:
        A callable ``loss_func(x, y)`` returning the per-vector loss.
    """
    if offset is None:
        if gamma == 1:
            if mul == 1 and clip == np.inf:
                # Nothing to scale or clip: plain l1_loss is exactly the result.
                return l1_loss
            # For gamma == 1 the offset cancels out exactly:
            # (m*d + o)**1 - o**1 == m*d, so any offset works; use 0.
            # (Previously l1_loss was returned here, silently ignoring mul/clip.)
            offset = 0.0
        else:
            # Pick the offset where d(x**p)/dx == 1:
            # p * x**(p-1) == 1  ==>  x = (1/p)**(1/(p-1)).
            # Then the loss has slope `mul` at d == 0 for every gamma.
            offset = (1 / gamma) ** (1 / (gamma - 1))

    def loss_func(x, y):
        return (mul * l1_loss(x, y).clip(max=clip) + offset) ** gamma - offset ** gamma

    return loss_func


def meta_gamma_loss():
    """Return a factory mapping an exponent ``alpha`` to ``gamma_loss(alpha)``.

    The produced losses use the default ``mul``, ``offset`` and ``clip``.
    """
    return lambda alpha: gamma_loss(alpha)