RFTSystems committed
Commit bb44a6f · verified · 1 Parent(s): 9ffa31d

Update dclr_optimizer.py

Files changed (1)
  1. dclr_optimizer.py +46 -0
dclr_optimizer.py CHANGED
@@ -0,0 +1,46 @@
+import torch
+import torch.nn.functional as F
+import math
+from torch.optim import Optimizer  # base class for custom optimizers
+
+class DCLR(Optimizer):
+    def __init__(self, params, lr=0.01, lambda_=1.0, epsilon=1e-8, delta=1e-12, verbose=True):
+        defaults = dict(lr=lr, lambda_=lambda_, epsilon=epsilon, delta=delta, verbose=verbose)
+        super(DCLR, self).__init__(params, defaults)
+
+    def step(self, closure=None, output_activations=None):
+        if output_activations is None:
+            raise ValueError("Output activations must be provided to compute entropy.")
+
+        loss = None
+        if closure is not None:
+            loss = closure()
+
+        # Mean Shannon entropy of the softmax output distribution.
+        probs = F.softmax(output_activations, dim=1)
+        log_probs = torch.log(probs + self.defaults['delta'])  # delta guards against log(0)
+        entropy = -torch.sum(probs * log_probs, dim=1).mean()
+
+        for group in self.param_groups:
+            lr_0 = group['lr']
+            lambda_ = group['lambda_']
+            epsilon = group['epsilon']
+            verbose = group['verbose']
+
+            for p in group['params']:
+                if p.grad is None:
+                    continue
+                grad = p.grad.data
+                grad_norm_sq = grad.norm() ** 2
+
+                # Entropy-modulated learning rate:
+                # eta(t) = lr_0 * exp(-lambda_ * ||grad||^2 / (entropy + epsilon))
+                eta_t = lr_0 * math.exp(-lambda_ * grad_norm_sq.item() / (entropy.item() + epsilon))
+
+                if verbose:
+                    print(f"[DCLR] Entropy: {entropy.item():.6f} | GradNorm²: {grad_norm_sq.item():.6f} | η(t): {eta_t:.6e}")
+
+                # Use the non-deprecated add_ signature: add_(Tensor other, *, Number alpha=1).
+                p.data.add_(grad, alpha=-eta_t)
+
+        return loss
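
For context on what the added optimizer does: step() rescales the base learning rate by exp(-lambda_ * ||grad||^2 / (H + epsilon)), where H is the mean Shannon entropy of the model's softmax outputs, so updates shrink when gradient norms are large relative to output entropy. Below is a minimal usage sketch, not part of this commit; the model, batch, and loss are hypothetical stand-ins. Because step() computes the entropy from the raw output activations, the training loop has to keep the logits and pass them in explicitly:

import torch
import torch.nn as nn

# Hypothetical setup: any model emitting per-class logits along dim=1 works.
model = nn.Linear(20, 5)
criterion = nn.CrossEntropyLoss()
optimizer = DCLR(model.parameters(), lr=0.01, lambda_=1.0, verbose=False)

inputs = torch.randn(8, 20)          # dummy batch
targets = torch.randint(0, 5, (8,))  # dummy labels

logits = model(inputs)               # keep the logits: step() needs them
loss = criterion(logits, targets)

optimizer.zero_grad()
loss.backward()
# Detach so the entropy term inside step() stays out of the autograd graph.
optimizer.step(output_activations=logits.detach())

A closure can also be passed, in which case step() calls it first and returns its result, matching the usual torch.optim step(closure) contract.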