File size: 989 Bytes
ad16788 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
def add_gradient_noise(
    model: torch.nn.Module,
    iteration: int,
    duration: float = 100,
    eta: float = 1.0,
    scale_factor: float = 0.55,
) -> None:
    """Add annealed Gaussian noise to the gradients of ``model`` in place.

    Every parameter whose ``.grad`` is populated receives noise drawn from
    ``N(0, sigma**2)``, where the standard deviation ``sigma`` decays in
    steps::

        interval = iteration // duration + 1
        sigma = eta / interval ** scale_factor

    ``sigma`` is constant for each block of ``duration`` iterations and then
    shrinks, going to zero (no noise) as ``iteration`` grows. Parameters
    whose ``.grad`` is ``None`` are skipped.

    Args:
        model: Model whose parameter gradients are perturbed in place.
        iteration: Current training iteration (non-negative); ``sigma``
            equals ``eta`` while ``iteration < duration``.
        duration: Number of iterations per step of the ``sigma`` schedule
            (typical values: 100 or 1000). Must be positive, otherwise the
            schedule divides by zero.
        eta: Magnitude of ``sigma``; its value during the first
            ``duration`` iterations (typical values: 0.01, 0.3, 1.0).
        scale_factor: Decay exponent applied to ``interval``
            (typical value: 0.55).
    """
    interval = iteration // duration + 1
    sigma = eta / interval**scale_factor
    # NOTE(review): Neelakantan et al. (2015) define eta / (1 + t) ** gamma as
    # the noise *variance*; here the same form is used as the standard
    # deviation. Kept as-is for backward compatibility — confirm intent.

    # The perturbation must not be recorded by autograd (relevant when the
    # gradients were produced with create_graph=True).
    with torch.no_grad():
        for param in model.parameters():
            grad = param.grad
            if grad is None:
                continue
            # randn_like matches the gradient's shape, dtype and device, so no
            # float32 CPU tensor is allocated and then copied/cast each step.
            grad.add_(torch.randn_like(grad), alpha=sigma)
|