# DCLR_Optimiser / dclr_optimizer.py
# RFTSystems — "Update dclr_optimizer.py" (commit bb44a6f, verified)
import torch
import torch.nn.functional as F
import math
from torch.optim import Optimizer # Ensure Optimizer is imported for custom classes
class DCLR(Optimizer):
    """Entropy-damped gradient descent ("DCLR").

    Every parameter tensor ``p`` that has a gradient receives a plain
    gradient-descent step. The step size is rescaled from the squared gradient
    norm of that tensor and the mean softmax entropy of the model outputs::

        q     = softmax(output_activations, dim=1)
        H     = mean( -sum_{dim 1} q * log(q + delta) )
        eta_t = lr * exp(-lambda_ * ||grad(p)||^2 / (H + epsilon))
        p    <- p - eta_t * grad(p)

    ``eta_t`` is computed separately for each parameter tensor. With
    ``lambda_ >= 0``, a large gradient relative to the output entropy pushes
    the step size towards zero.

    Args:
        params: iterable of parameters, or of dicts defining parameter groups.
        lr: base learning rate ``lr_0``; must be non-negative.
        lambda_: strength of the exponential damping.
        epsilon: added to the entropy in the denominator; must be non-negative.
        delta: added to the probabilities before ``log``; must be
            non-negative. May be overridden per parameter group.
        verbose: if true, print the entropy, squared gradient norm and step
            size for every parameter tensor that is updated.

    Raises:
        ValueError: if ``lr``, ``epsilon`` or ``delta`` is negative.
    """

    def __init__(self, params, lr=0.01, lambda_=1.0, epsilon=1e-8, delta=1e-12, verbose=True):
        if lr < 0.0:
            raise ValueError(f"Invalid learning rate: {lr}")
        if epsilon < 0.0:
            raise ValueError(f"Invalid epsilon value: {epsilon}")
        if delta < 0.0:
            # A negative delta can make log(q + delta) NaN and poison the weights.
            raise ValueError(f"Invalid delta value: {delta}")
        defaults = dict(lr=lr, lambda_=lambda_, epsilon=epsilon, delta=delta, verbose=verbose)
        super().__init__(params, defaults)

    @staticmethod
    def _mean_entropy(probs, delta):
        """Return ``mean(-sum_{dim 1} probs * log(probs + delta))`` as a float.

        Raises:
            ValueError: if the result is not finite (e.g. an empty batch or
                NaN activations). A NaN entropy would make the step size NaN
                and turn every updated weight into NaN.
        """
        log_probs = torch.log(probs + delta)
        entropy = -torch.sum(probs * log_probs, dim=1).mean().item()
        if not math.isfinite(entropy):
            raise ValueError(
                f"Entropy of output activations is not finite ({entropy}); "
                "cannot compute a step size."
            )
        return entropy

    @torch.no_grad()
    def step(self, closure=None, output_activations=None):
        """Perform a single optimization step.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss. It is run with autograd enabled.
            output_activations: model outputs used for the entropy term
                (required). Softmax is taken over dim 1, so this must be at
                least 2-D with the class axis at dim 1 — presumably
                ``(batch, classes)`` logits. NOTE: the entropy is computed from
                this tensor as passed in, not from anything ``closure``
                computes.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure is
            given.

        Raises:
            ValueError: if ``output_activations`` is missing, or if its mean
                entropy is not finite.
        """
        if output_activations is None:
            raise ValueError("Output activations must be provided to compute entropy.")

        loss = None
        if closure is not None:
            # step() itself runs under no_grad. The closure normally calls
            # backward(), so autograd has to be re-enabled for it.
            with torch.enable_grad():
                loss = closure()

        probs = F.softmax(output_activations, dim=1)
        # Honour each group's own delta. Compute and validate every entropy
        # before any parameter is modified, so a bad batch cannot leave the
        # model half-updated.
        entropies = [self._mean_entropy(probs, group['delta']) for group in self.param_groups]

        for group, entropy in zip(self.param_groups, entropies):
            lr_0 = group['lr']
            lambda_ = group['lambda_']
            epsilon = group['epsilon']
            verbose = group['verbose']
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad
                grad_norm_sq = (grad.norm() ** 2).item()
                eta_t = lr_0 * math.exp(-lambda_ * grad_norm_sq / (entropy + epsilon))
                if verbose:
                    print(f"[DCLR] Entropy: {entropy:.6f} | GradNorm²: {grad_norm_sq:.6f} | η(t): {eta_t:.6e}")
                # In-place SGD-style update: p <- p - eta_t * grad.
                p.add_(grad, alpha=-eta_t)
        return loss