# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                                                           #
#     This file was created by: Alberto Palomo Alonso       #
#   Universidad de Alcalá - Escuela Politécnica Superior    #
#                                                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Import statements:
import math
import torch

# Numerical floor to avoid division by zero and log10(0):
EPS = 1e-12
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
#                         REGISTER                           #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
def watch_max(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Maximum absolute value of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].abs().max())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].abs().max())
    else:
        return float(tensor[mask].abs().max())

def watch_min(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Minimum absolute value of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].abs().min())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].abs().min())
    else:
        return float(tensor[mask].abs().min())

def watch_mean(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Mean of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].mean())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].mean())
    else:
        return float(tensor[mask].mean())

def watch_var(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Variance of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].var())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].var())
    else:
        return float(tensor[mask].var())

def watch_std(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Standard deviation of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].std())
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].std())
    else:
        return float(tensor[mask].std())

def watch_sparsity(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
        sparsity_threshold: float = 1e-6,
) -> float:
    """Fraction of masked entries whose magnitude is at or below ``sparsity_threshold``."""
    if grad:
        return float((tensor.grad[mask].abs() <= sparsity_threshold).float().mean())
    elif hasattr(tensor, 'data'):
        return float((tensor.data[mask].abs() <= sparsity_threshold).float().mean())
    else:
        return float((tensor[mask].abs() <= sparsity_threshold).float().mean())
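
# Worked example (illustrative, not from the original file): with the
# default threshold, exactly one of the two entries below counts as
# zero, so the reported sparsity is 0.5:
#   >>> watch_sparsity(torch.tensor([0.0, 1.0]), torch.ones(2, dtype=torch.bool))
#   0.5
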
def watch_l1(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """L1 norm of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].norm(p=1))
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].norm(p=1))
    else:
        return float(tensor[mask].norm(p=1))

def watch_l2(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """L2 norm of the masked entries (of ``tensor.grad`` when ``grad``)."""
    if grad:
        return float(tensor.grad[mask].norm(p=2))
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask].norm(p=2))
    else:
        return float(tensor[mask].norm(p=2))

def watch_snr(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> None | float:
    """Signal-to-noise ratio in dB, 20 * log10(|mean| / std), over the masked entries."""
    std = watch_std(tensor, mask, grad=grad)
    if std <= 0:
        return None
    elif grad:
        val = float(torch.log10((tensor.grad[mask].mean()).abs() / (std + EPS)))
    elif hasattr(tensor, 'data'):
        val = float(torch.log10((tensor.data[mask].mean()).abs() / (std + EPS)))
    else:
        val = float(torch.log10((tensor[mask].mean()).abs() / (std + EPS)))
    return 20 * val if math.isfinite(val) else None  # Guard against -inf (zero mean) and NaN.
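
# Worked example (illustrative): a constant tensor has zero standard
# deviation, so the SNR is undefined and the guard returns None:
#   >>> watch_snr(torch.ones(4), torch.ones(4, dtype=torch.bool))  # None
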
def watch_hist(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> torch.Tensor:
    """Raw masked values (of ``tensor.grad`` when ``grad``), e.g. for histogram logging."""
    if grad:
        return tensor.grad[mask]
    elif hasattr(tensor, 'data'):
        return tensor.data[mask]
    else:
        return tensor[mask]

def watch_rank(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
        threshold: float = 0.92,
) -> None | float:
    """Effective rank: number of singular values needed to hold ``threshold`` of the energy."""
    if grad:
        work_tensor = tensor.grad
    elif hasattr(tensor, 'data'):
        work_tensor = tensor.data
    else:
        work_tensor = tensor
    work_tensor = torch.multiply(work_tensor, mask.float())
    if work_tensor.ndim < 2:
        return None  # An SVD-based rank is only defined for 2D tensors and above.
    else:
        # Compute the singular values and sort them in descending order:
        work_tensor = torch.linalg.svdvals(work_tensor)
        work_tensor = torch.sort(work_tensor, descending=True).values
        # Cumulative energy (normalized cumulative sum of squared singular values):
        work_tensor = torch.cumsum(work_tensor ** 2, dim=0) / (torch.sum(work_tensor ** 2) + EPS)
        # Effective rank: singular values needed for the energy to reach the threshold:
        return float(torch.sum(work_tensor < threshold).item() + 1)
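
# Worked example (illustrative): a rank-1 outer product concentrates all
# of its energy in the first singular value, so the cumulative energy hits
# the threshold immediately and the effective rank is 1:
#   >>> u = torch.ones(4, 1)
#   >>> watch_rank(u @ u.T, torch.ones(4, 4, dtype=torch.bool))
#   1.0
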
def watch_any(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Scalar register (loss, learning rate, ...); the mask must select a single element."""
    if grad:
        return float(tensor.grad[mask])
    elif hasattr(tensor, 'data'):
        return float(tensor.data[mask])
    else:
        return float(tensor[mask])

def watch_power(
        tensor: torch.Tensor,
        mask: torch.Tensor,
        grad: bool = False,
) -> float:
    """Mean power of the masked entries in dB: 10 * log10(mean(x ** 2))."""
    if grad:
        return float(10 * torch.log10((tensor.grad[mask] ** 2).mean() + EPS))
    elif hasattr(tensor, 'data'):
        return float(10 * torch.log10((tensor.data[mask] ** 2).mean() + EPS))
    else:
        return float(10 * torch.log10((tensor[mask] ** 2).mean() + EPS))
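
# Worked example (illustrative): entries of constant magnitude 2 have a
# mean power of 4, i.e. 10 * log10(4) ≈ 6.02 dB:
#   >>> watch_power(torch.full((4,), 2.0), torch.ones(4, dtype=torch.bool))  # ≈ 6.02
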
# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
#                         FUNC. MAP                          #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
REG_FUNCTION_MAP = {
    # Function mapping:
    'max': watch_max,
    'min': watch_min,
    'mean': watch_mean,
    'std': watch_std,
    'var': watch_var,
    'l2': watch_l2,
    'l1': watch_l1,
    'sparsity': watch_sparsity,
    'snr': watch_snr,
    'hist': watch_hist,
    'rank': watch_rank,
    'power': watch_power,
    # Gradient mapping:
    'grad_max': lambda x, y: watch_max(x, y, grad=True),
    'grad_min': lambda x, y: watch_min(x, y, grad=True),
    'grad_mean': lambda x, y: watch_mean(x, y, grad=True),
    'grad_std': lambda x, y: watch_std(x, y, grad=True),
    'grad_var': lambda x, y: watch_var(x, y, grad=True),
    'grad_l1': lambda x, y: watch_l1(x, y, grad=True),
    'grad_l2': lambda x, y: watch_l2(x, y, grad=True),
    'grad_sparsity': lambda x, y: watch_sparsity(x, y, grad=True),
    'grad_snr': lambda x, y: watch_snr(x, y, grad=True),
    'grad_hist': lambda x, y: watch_hist(x, y, grad=True),
    'grad_rank': lambda x, y: watch_rank(x, y, grad=True),
    'grad_power': lambda x, y: watch_power(x, y, grad=True),
    # Training scalars (loss, validation loss, learning rate):
    'loss': watch_any,
    'val_loss': watch_any,
    'lr': watch_any,
}
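# - # - # - # - # - # - # - # - # - # - # - # - # - # - # - #
#                       USAGE SKETCH                         #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
# Minimal usage sketch: how a caller might look a register up in
# REG_FUNCTION_MAP and apply it to a parameter under a mask. This demo
# is an assumption about the consuming code, not part of the original
# API; the real call sites live elsewhere in the project.
if __name__ == '__main__':
    _param = torch.nn.Parameter(torch.randn(8, 8))
    _param.sum().backward()  # Populate _param.grad so grad_* registers work.
    _mask = torch.ones_like(_param, dtype=torch.bool)
    for _name in ('mean', 'std', 'l2', 'grad_max', 'rank'):
        print(_name, REG_FUNCTION_MAP[_name](_param, _mask))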
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #
#                        END OF FILE                         #
# - x - x - x - x - x - x - x - x - x - x - x - x - x - x - #