import torch
from torch.optim import Optimizer


class PerPointAdam(Optimizer):
"""Implements Adam optimizer with per-point learning rates. | |
Allows unique learning rates for each point in specified parameter tensors, | |
useful for point cloud optimization. | |
Args: | |
params: Iterable of parameters to optimize or parameter groups | |
lr (float, optional): Default learning rate (default: 1e-3) | |
betas (tuple, optional): Coefficients for moving averages (default: (0.9, 0.999)) | |
eps (float, optional): Term for numerical stability (default: 1e-8) | |
weight_decay (float, optional): Weight decay (L2 penalty) (default: 0) | |
""" | |
    def __init__(self, params, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, weight_decay=0):
        if not all(0.0 <= x for x in (lr, eps, weight_decay)):
            raise ValueError(f"Invalid learning parameters: lr={lr}, eps={eps}, weight_decay={weight_decay}")
        if not all(0.0 <= beta < 1.0 for beta in betas):
            raise ValueError(f"Invalid beta parameters: {betas}")
        defaults = dict(lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, per_point_lr=None)
        super().__init__(params, defaults)
    def _adjust_per_point_lr(self, per_point_lr, grad, mask):
        """Adjusts per-point learning rates based on gradient magnitudes.

        Points with non-zero gradients have their multiplier scaled by
        0.99 + 0.02 * sigmoid(|grad|); since the magnitude is non-negative,
        this factor lies in [1.0, 1.01).
        """
        grad_magnitude = grad.norm(dim=-1)
        scaling_factor = torch.ones_like(grad_magnitude)
        grad_sigmoid = torch.sigmoid(grad_magnitude[mask])
        scaling_factor[mask] = 0.99 + grad_sigmoid * 0.02
        # Reshape to (N, 1, ..., 1) so the factor broadcasts against per_point_lr
        return per_point_lr * scaling_factor.reshape(per_point_lr.shape)
    def step(self, closure=None):
        """Performs a single optimization step."""
        loss = closure() if closure is not None else None
        for group in self.param_groups:
            per_point_lr = group.get('per_point_lr')
            for p in group['params']:
                if p.grad is None:
                    continue
                grad = p.grad.data
                if grad.is_sparse:
                    raise RuntimeError('PerPointAdam does not support sparse gradients')
                # Initialize state if needed
                state = self.state[p]
                if len(state) == 0:
                    state['step'] = 0
                    state['exp_avg'] = torch.zeros_like(p.data)
                    state['exp_avg_sq'] = torch.zeros_like(p.data)

                # Get state values
                exp_avg, exp_avg_sq = state['exp_avg'], state['exp_avg_sq']
                beta1, beta2 = group['betas']
                state['step'] += 1

                # Apply weight decay if specified
                if group['weight_decay'] != 0:
                    grad = grad.add(p.data, alpha=group['weight_decay'])
                # Mask of points (rows along dim 0) whose gradient is non-zero
                point_mask = grad.norm(dim=-1) > 0

                # Update biased first and second moment estimates for active points only
                exp_avg[point_mask] = beta1 * exp_avg[point_mask] + (1 - beta1) * grad[point_mask]
                exp_avg_sq[point_mask] = beta2 * exp_avg_sq[point_mask] + (1 - beta2) * grad[point_mask].pow(2)
                # Compute bias corrections
                bias_correction1 = 1 - beta1 ** state['step']
                bias_correction2 = 1 - beta2 ** state['step']

                # Compute denominator and step size
                denom = exp_avg_sq.sqrt().add_(group['eps'])
                step_size = group['lr'] * (bias_correction2 ** 0.5 / bias_correction1)
                # Apply updates
                if per_point_lr is not None:
                    if not isinstance(per_point_lr, torch.Tensor):
                        raise TypeError("per_point_lr must be a torch.Tensor")
                    if per_point_lr.device != p.data.device:
                        raise ValueError("per_point_lr must be on the same device as the parameter")
                    expected_shape = p.data.shape[:1] + (1,) * (p.data.dim() - 1)
                    if per_point_lr.shape != expected_shape:
                        raise ValueError(
                            f"{group.get('name', 'param group')}: invalid per_point_lr shape. "
                            f"Expected {expected_shape}, got {per_point_lr.shape}")
                    scaled_step_size = step_size * per_point_lr
                    p.data.add_(-scaled_step_size * (exp_avg / denom))
                    # Adjust the rates and write them back so the next step sees the update
                    per_point_lr = self._adjust_per_point_lr(per_point_lr, grad, point_mask)
                    group['per_point_lr'] = per_point_lr
                else:
                    p.data.addcdiv_(exp_avg, denom, value=-step_size)
        return loss