# Uploaded by xiangzai via the upload-large-folder tool (revision 3e4f775, verified).
import torch
from torch import nn
class Regularizer(nn.Module):
    """Placeholder ``nn.Module`` with no parameters or behavior of its own."""

    def __init__(self):
        # nn.Module.__init__ creates the internal _parameters/_buffers/_modules
        # registries; skipping it leaves an object on which parameters(), to(),
        # repr(), etc. raise AttributeError.
        super().__init__()
def _batch_root_mean_squared(tensor):
tensor = tensor.view(tensor.shape[0], -1)
return torch.norm(tensor, p=2, dim=1) / tensor.shape[1] ** 0.5
class RegularizationFunc(nn.Module):
    """Abstract base for per-sample regularization terms.

    Subclasses override :meth:`forward` with the signature
    ``(t, x, dx, context)`` and return a 1-D tensor holding one scalar
    regularization value per batch element.
    """

    def forward(self, t, x, dx, context) -> torch.Tensor:
        """Return a batch of scalar regularizations (must be overridden)."""
        raise NotImplementedError
class L1Reg(RegularizationFunc):
    """Mean absolute value of ``dx`` per sample."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        # Flatten trailing dims so a multi-dimensional dx (e.g. (N, C, H, W))
        # still yields one scalar per sample, as the L2 variants do. For 2-D dx
        # this equals the previous ``mean(|dx|, dim=1)``.
        return dx.reshape(dx.shape[0], -1).abs().mean(dim=1)
class L2Reg(RegularizationFunc):
    """Per-sample root-mean-square of ``dx`` (L2 norm scaled by 1/sqrt(D))."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        per_sample_rms = _batch_root_mean_squared(dx)
        return per_sample_rms
class SquaredL2Reg(RegularizationFunc):
    """Squared L2 norm of ``dx`` per sample (sum of squared entries)."""

    def forward(self, t, x, dx, context) -> torch.Tensor:
        # reshape (not view) so non-contiguous dx is accepted. Summing squares
        # directly avoids the sqrt-then-square round trip of norm(...) ** 2.
        flat = dx.reshape(dx.shape[0], -1)
        return flat.pow(2).sum(dim=1)
def _get_minibatch_jacobian(y, x, create_graph=True):
"""Computes the Jacobian of y wrt x assuming minibatch-mode.
Args:
y: (N, ...) with a total of D_y elements in ...
x: (N, ...) with a total of D_x elements in ...
Returns:
The minibatch Jacobian matrix of shape (N, D_y, D_x)
"""
# assert y.shape[0] == x.shape[0]
y = y.view(y.shape[0], -1)
# Compute Jacobian row by row.
jac = []
for j in range(y.shape[1]):
dy_j_dx = torch.autograd.grad(
y[:, j],
x,
torch.ones_like(y[:, j]),
retain_graph=True,
create_graph=create_graph,
)[0]
jac.append(torch.unsqueeze(dy_j_dx, -1))
jac = torch.cat(jac, -1)
return jac
class JacobianFrobeniusReg(RegularizationFunc):
    """Per-sample root-mean-square of all entries of the Jacobian d(dx)/d(x).

    The Jacobian is read from ``context.jac`` when present. Otherwise it is
    computed and stored there, so other Jacobian-based regularizers sharing the
    same context can reuse it.
    """

    def forward(self, t, x, dx, context) -> torch.Tensor:
        jac = getattr(context, "jac", None)
        if jac is None:
            # Compute only on a cache miss. Previously the Jacobian was always
            # recomputed and overwritten, which defeated the cache.
            jac = _get_minibatch_jacobian(dx, x)
            context.jac = jac
        return _batch_root_mean_squared(jac)
class JacobianDiagFrobeniusReg(RegularizationFunc):
    """Per-sample root-mean-square of the diagonal of the Jacobian d(dx)/d(x).

    Assumes the Jacobian is square, i.e. dx has as many elements per sample as
    x. The Jacobian is read from ``context.jac`` when present, otherwise
    computed and cached there.
    """

    def forward(self, t, x, dx, context) -> torch.Tensor:
        jac = getattr(context, "jac", None)
        if jac is None:
            jac = _get_minibatch_jacobian(dx, x)
            context.jac = jac
        # For a square M x M Jacobian flattened row-major, the diagonal sits at
        # flat indices k * (M + 1). A stride of M would select the first column.
        m = jac.shape[-1]
        flat = jac.reshape(jac.shape[0], -1)
        diagonal = flat[:, :: m + 1]
        return _batch_root_mean_squared(diagonal)
class JacobianOffDiagFrobeniusReg(RegularizationFunc):
    """Per-sample mean of the squared off-diagonal entries of d(dx)/d(x).

    Note that this returns a mean square, not a root-mean-square. Assumes the
    Jacobian is square (M x M per sample). The Jacobian is read from
    ``context.jac`` when present, otherwise computed and cached there.
    """

    def forward(self, t, x, dx, context) -> torch.Tensor:
        jac = getattr(context, "jac", None)
        if jac is None:
            jac = _get_minibatch_jacobian(dx, x)
            context.jac = jac
        m = jac.shape[-1]
        flat = jac.reshape(jac.shape[0], -1)
        # Diagonal entries of a row-major M x M matrix are at flat indices
        # k * (M + 1). Mask them out to keep exactly the M * (M - 1) others.
        offdiag_mask = torch.ones(flat.shape[1], dtype=torch.bool, device=flat.device)
        offdiag_mask[:: m + 1] = False
        off_diagonal = flat[:, offdiag_mask]
        # M == 1 has no off-diagonal entries; divide by 1 to return 0, not 0/0.
        return off_diagonal.pow(2).sum(dim=1) / max(off_diagonal.shape[1], 1)
def autograd_trace(x_out, x_in, **kwargs):
    """Return the exact per-sample trace of d(x_out)/d(x_in) by brute force.

    Uses one ``torch.autograd.grad`` call per column of ``x_in``, i.e. O(d)
    backward passes. Differentiating ``x_out[:, i].sum()`` and keeping column
    ``i`` of the result gives d x_out[n, i] / d x_in[n, i] for every sample
    ``n``, provided samples do not interact.

    Args:
        x_out: output computed from ``x_in``; indexed as (N, D).
        x_in: input tensor requiring grad; indexed as (N, D).
        **kwargs: accepted so callers can pass estimator options such as
            ``noise=``; ignored here.

    Returns:
        (N,) tensor of traces, built with ``create_graph=True`` so it can be
        backpropagated through. Returns the float ``0.0`` when
        ``x_in.shape[1] == 0``.
    """

    def _diagonal_column(i):
        (grad_i,) = torch.autograd.grad(
            x_out[:, i].sum(), x_in, allow_unused=False, create_graph=True
        )
        return grad_i[:, i]

    return sum((_diagonal_column(i) for i in range(x_in.shape[1])), 0.0)
class CNFReg(RegularizationFunc):
    """Negated trace of the Jacobian of ``dx`` with respect to ``x``.

    Presumably this is the log-density change term of a continuous
    normalizing flow (d log p / dt = -tr(d dx / d x)); confirm against the
    training loop.

    Args:
        trace_estimator: callable ``(x_out, x_in, noise=...) -> trace``.
            Defaults to :func:`autograd_trace` (exact, one autograd call per
            dimension).
        noise_dist: stored on the module; not used by ``forward`` here.
    """

    def __init__(self, trace_estimator=None, noise_dist=None):
        super().__init__()
        if trace_estimator is None:
            trace_estimator = autograd_trace
        self.trace_estimator = trace_estimator
        self.noise_dist = noise_dist
        self.noise = None

    def forward(self, t, x, dx, context):
        # TODO: when the context already carries a Jacobian ("jac"), its trace
        # could be reused here instead of recomputing.
        trace = self.trace_estimator(dx, x, noise=self.noise)
        return -trace
class AugmentationModule(nn.Module):
    """Collect the configured regularizers and their weights, in a fixed order.

    The order is: the exact CNF trace term (when ``cnf_estimator == "exact"``),
    then L1, L2, squared L2, Jacobian Frobenius, Jacobian diagonal and Jacobian
    off-diagonal. A weighted term is kept only when its weight is strictly
    positive. ``self.coeffs[i]`` is the weight of ``self.regs[i]``.

    Args:
        cnf_estimator: ``"exact"`` adds :class:`CNFReg` with weight 1 using the
            brute-force autograd trace. Any other value (including ``None``)
            adds no CNF term.
        l1_reg: weight of :class:`L1Reg`.
        l2_reg: weight of :class:`L2Reg`.
        squared_l2_reg: weight of :class:`SquaredL2Reg`.
        jacobian_frobenius_reg: weight of :class:`JacobianFrobeniusReg`.
        jacobian_diag_frobenius_reg: weight of :class:`JacobianDiagFrobeniusReg`.
        jacobian_off_diag_frobenius_reg: weight of
            :class:`JacobianOffDiagFrobeniusReg`.
    """

    def __init__(
        self,
        cnf_estimator: str = None,
        l1_reg: float = 0.0,
        l2_reg: float = 0.0,
        squared_l2_reg: float = 0.0,
        jacobian_frobenius_reg: float = 0.0,
        jacobian_diag_frobenius_reg: float = 0.0,
        jacobian_off_diag_frobenius_reg: float = 0.0,
    ) -> None:
        super().__init__()
        coeffs = []
        regs = []
        if cnf_estimator == "exact":
            coeffs.append(1.0)
            regs.append(CNFReg(None, noise_dist=None))

        # (weight, regularizer class) pairs; tuple order fixes the order of self.regs.
        weighted_terms = (
            (l1_reg, L1Reg),
            (l2_reg, L2Reg),
            (squared_l2_reg, SquaredL2Reg),
            (jacobian_frobenius_reg, JacobianFrobeniusReg),
            (jacobian_diag_frobenius_reg, JacobianDiagFrobeniusReg),
            (jacobian_off_diag_frobenius_reg, JacobianOffDiagFrobeniusReg),
        )
        for coeff, reg_cls in weighted_terms:
            if coeff > 0.0:
                coeffs.append(coeff)
                regs.append(reg_cls())

        # Register as a non-persistent buffer so .to(device) moves coeffs with
        # the module without adding a state_dict key. Use an explicit float dtype
        # so an int weight (e.g. only the CNF term, weight 1) doesn't give int64.
        self.register_buffer(
            "coeffs",
            torch.tensor(coeffs, dtype=torch.get_default_dtype()),
            persistent=False,
        )
        # nn.ModuleList (torch has no top-level ModuleList) registers each
        # regularizer as a submodule.
        self.regs = nn.ModuleList(regs)
if __name__ == "__main__":
    # Smoke test: every regularizer must return one scalar per batch element.

    class SharedContext:
        """Bare attribute bag; Jacobian regularizers cache ``jac`` on it."""

    for reg_cls in (
        L1Reg,
        L2Reg,
        SquaredL2Reg,
        JacobianFrobeniusReg,
        JacobianDiagFrobeniusReg,
        JacobianOffDiagFrobeniusReg,
    ):
        x = torch.ones(2, 3).requires_grad_(True)
        dx = x * 2
        # Use a fresh context *instance* per regularizer. Passing the class
        # itself stores ``jac`` as a class attribute, so later iterations would
        # reuse a stale Jacobian computed for a previous ``x``.
        out = reg_cls().forward(torch.ones(1), x, dx, SharedContext())
        assert out.dim() == 1, f"{reg_cls.__name__}: expected 1-D output, got {tuple(out.shape)}"
        assert out.shape[0] == 2, f"{reg_cls.__name__}: expected batch size 2, got {out.shape[0]}"