import torch

# Numerical guard against division by zero in ratio-based metrics.
EPS = 1e-12

# Each watch_* function takes a tensor, a boolean mask selecting the entries
# to summarise, and a ``grad`` flag; with grad=True the statistic is computed
# over ``tensor.grad`` instead of the tensor's values.


def watch_max(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Largest absolute value among the masked entries."""
    if grad:
        return float(tensor.grad[mask].abs().max())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].abs().max())
    else:
        return float(tensor[mask].abs().max())


def watch_min(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Smallest absolute value among the masked entries."""
    if grad:
        return float(tensor.grad[mask].abs().min())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].abs().min())
    else:
        return float(tensor[mask].abs().min())


def watch_mean(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Mean of the masked entries."""
    if grad:
        return float(tensor.grad[mask].mean())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].mean())
    else:
        return float(tensor[mask].mean())


def watch_var(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Variance of the masked entries (unbiased, torch's default)."""
    if grad:
        return float(tensor.grad[mask].var())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].var())
    else:
        return float(tensor[mask].var())


def watch_std(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Standard deviation of the masked entries (unbiased, torch's default)."""
    if grad:
        return float(tensor.grad[mask].std())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].std())
    else:
        return float(tensor[mask].std())


def watch_sparsity(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
    sparsity_threshold: float = 1e-6,
) -> float:
    """Fraction of masked entries whose magnitude is at most ``sparsity_threshold``."""
    if grad:
        return float((tensor.grad[mask].abs() <= sparsity_threshold).float().mean())
    elif hasattr(tensor, 'data'):
        return float((tensor.data[mask].abs() <= sparsity_threshold).float().mean())
    else:
        return float((tensor[mask].abs() <= sparsity_threshold).float().mean())


def watch_l1(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """L1 norm of the masked entries."""
    if grad:
        return float(tensor.grad[mask].norm(p=1))
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].norm(p=1))
    else:
        return float(tensor[mask].norm(p=1))


def watch_l2(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """L2 norm of the masked entries."""
    if grad:
        return float(tensor.grad[mask].norm(p=2))
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].norm(p=2))
    else:
        return float(tensor[mask].norm(p=2))


def watch_snr(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float | None:
    """Signal-to-noise ratio of the masked entries in dB, i.e.
    20 * log10(|mean| / std). Returns None when the std is non-positive
    or the mean is exactly zero (log10 would give -inf)."""
    std = watch_std(tensor, mask, grad=grad)
    if std <= 0:
        return None
    elif grad:
        val = float(torch.log10((tensor.grad[mask].mean()).abs() / (std + EPS)))
    elif hasattr(tensor, 'data'):
        val = float(torch.log10((tensor.data[mask].mean()).abs() / (std + EPS)))
    else:
        val = float(torch.log10((tensor[mask].mean()).abs() / (std + EPS)))
    return 20 * val if val != float("-inf") else None
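

# Worked example for watch_snr (a sketch, not part of the original module):
# for t = torch.tensor([1., 1., 3., 3.]) with an all-True mask, the mean is 2
# and the unbiased std is sqrt(4/3) ~ 1.155, so watch_snr(t, mask) returns
# about 20 * log10(2 / 1.155) ~ 4.77 dB.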


def watch_hist(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> torch.Tensor:
    """Raw masked values, returned unreduced for downstream histogram logging."""
    if grad:
        return tensor.grad[mask]
    elif hasattr(tensor, 'data'):
        return tensor.data[mask]
    else:
        return tensor[mask]


def watch_rank(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
    threshold: float = 0.92,
) -> float | None:
    """Effective rank: the number of leading singular values needed to
    capture a ``threshold`` fraction of the squared spectral energy.
    Returns None for tensors with fewer than two dimensions."""
    if grad:
        work_tensor = tensor.grad
    elif hasattr(tensor, 'data'):
        work_tensor = tensor.data
    else:
        work_tensor = tensor
    # Zero out excluded entries rather than gathering them, so the matrix
    # shape is preserved for the SVD.
    work_tensor = torch.multiply(work_tensor, mask.float())

    if work_tensor.ndim < 2:
        return None
    # Fold any trailing dimensions into the columns so svdvals sees a single
    # matrix rather than a batch (e.g. 4-D conv kernels).
    work_tensor = work_tensor.reshape(work_tensor.shape[0], -1)
    # svdvals already returns the singular values in descending order.
    work_tensor = torch.linalg.svdvals(work_tensor)
    # Cumulative fraction of squared spectral energy per singular value.
    work_tensor = torch.cumsum(work_tensor**2, dim=0) / (torch.sum(work_tensor**2) + EPS)
    return float(torch.sum(work_tensor < threshold).item() + 1)
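

# Worked example for watch_rank (a sketch, not part of the original module):
# a rank-1 matrix puts all of its squared spectral energy in the first
# singular value, so no cumulative fraction stays below the threshold and
# the reported rank is 1:
#   m = torch.outer(torch.arange(4.), torch.arange(5.))
#   watch_rank(m, torch.ones_like(m, dtype=torch.bool))  # -> 1.0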


def watch_any(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Pass-through for scalar series such as the loss or the learning rate.
    When a tensor is given, ``mask`` must select exactly one element."""
    # Plain Python scalars (e.g. a learning rate) carry neither .grad nor
    # .data and cannot be indexed by a mask, so return them directly.
    if isinstance(tensor, (int, float)):
        return float(tensor)
    if grad:
        return float(tensor.grad[mask])
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask])
    else:
        return float(tensor[mask])


def watch_power(
    tensor: torch.Tensor,
    mask: torch.Tensor,
    grad: bool = False,
) -> float:
    """Mean power of the masked entries in dB, i.e. 10 * log10(mean(x ** 2))."""
    if grad:
        return float(10 * torch.log10((tensor.grad[mask] ** 2).mean() + EPS))
    elif hasattr(tensor, 'data'):
        return float(10 * torch.log10((tensor.data[mask] ** 2).mean() + EPS))
    else:
        return float(10 * torch.log10((tensor[mask] ** 2).mean() + EPS))


# Registry mapping metric names to watch functions. ``grad_``-prefixed keys
# compute the same statistic over ``tensor.grad`` instead of the values.
REG_FUNCTION_MAP = {
    # Value statistics.
    'max': watch_max,
    'min': watch_min,
    'mean': watch_mean,
    'std': watch_std,
    'var': watch_var,
    'l2': watch_l2,
    'l1': watch_l1,
    'sparsity': watch_sparsity,
    'snr': watch_snr,
    'hist': watch_hist,
    'rank': watch_rank,
    'power': watch_power,

    # Gradient statistics.
    'grad_max': lambda x, y: watch_max(x, y, grad=True),
    'grad_min': lambda x, y: watch_min(x, y, grad=True),
    'grad_mean': lambda x, y: watch_mean(x, y, grad=True),
    'grad_std': lambda x, y: watch_std(x, y, grad=True),
    'grad_var': lambda x, y: watch_var(x, y, grad=True),
    'grad_l1': lambda x, y: watch_l1(x, y, grad=True),
    'grad_l2': lambda x, y: watch_l2(x, y, grad=True),
    'grad_sparsity': lambda x, y: watch_sparsity(x, y, grad=True),
    'grad_snr': lambda x, y: watch_snr(x, y, grad=True),
    'grad_hist': lambda x, y: watch_hist(x, y, grad=True),
    'grad_rank': lambda x, y: watch_rank(x, y, grad=True),
    'grad_power': lambda x, y: watch_power(x, y, grad=True),

    # Scalar training series.
    'loss': watch_any,
    'val_loss': watch_any,
    'lr': watch_any,
}
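

# Minimal usage sketch (an illustration, not part of the original API: the
# training loop that normally supplies tensors and masks is assumed here).
if __name__ == '__main__':
    weight = torch.nn.Linear(8, 4).weight
    mask = torch.ones_like(weight, dtype=torch.bool)  # watch every entry
    for name in ('max', 'mean', 'std', 'l2', 'sparsity', 'rank', 'power'):
        print(f'{name}: {REG_FUNCTION_MAP[name](weight, mask)}')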