import torch
from torch import nn


class DDPGradientStatsHook:
    def __init__(self, ddp_module):
        try:
            ddp_module.register_comm_hook(self, self._hook_fn)
        except AttributeError:
            raise ValueError('DDPGradientStatsHook does not support non-DDP wrapped modules')
        self._clear_state()

    def _clear_state(self):
        self.bucket_sq_norms_small_batch = []
        self.bucket_sq_norms_large_batch = []

    @staticmethod
    def _hook_fn(self, bucket):
        # DDP comm hooks are called as hook(state, bucket); the instance passed as the
        # state to register_comm_hook() arrives here as `self`, hence the static method.
        buf = bucket.buffer()
        # Squared 2-norm of this rank's (small batch) gradient bucket, before averaging.
        self.bucket_sq_norms_small_batch.append(buf.pow(2).sum(dtype=torch.float32))
        fut = torch.distributed.all_reduce(buf, op=torch.distributed.ReduceOp.AVG, async_op=True).get_future()

        def callback(fut):
            buf = fut.value()[0]
            # Squared 2-norm of the averaged (large batch) bucket, after the all-reduce.
            self.bucket_sq_norms_large_batch.append(buf.pow(2).sum(dtype=torch.float32))
            return buf

        return fut.then(callback)

    def get_stats(self):
        sq_norm_small_batch = sum(self.bucket_sq_norms_small_batch)
        sq_norm_large_batch = sum(self.bucket_sq_norms_large_batch)
        self._clear_state()
        stats = torch.stack([sq_norm_small_batch, sq_norm_large_batch])
        torch.distributed.all_reduce(stats, op=torch.distributed.ReduceOp.AVG)
        return stats[0].item(), stats[1].item()
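

# Illustrative sketch only, not part of the original module: one way to attach the
# comm hook to a DDP-wrapped model. The function name and arguments are assumptions
# for the example, and torch.distributed is assumed to be initialized already.
def _example_attach_gradient_stats_hook(model, device_ids):
    ddp_model = nn.parallel.DistributedDataParallel(model, device_ids=device_ids)
    hook = DDPGradientStatsHook(ddp_model)
    return ddp_model, hook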


class GradientNoiseScale:
    """Calculates the gradient noise scale (1 / SNR), or critical batch size,
    from _An Empirical Model of Large-Batch Training_,
    https://arxiv.org/abs/1812.06162.

    Args:
        beta (float): The decay factor for the exponential moving averages used to
            calculate the gradient noise scale.
            Default: 0.9998
        eps (float): Added for numerical stability.
            Default: 1e-8
    """

    def __init__(self, beta=0.9998, eps=1e-8):
        self.beta = beta
        self.eps = eps
        self.ema_sq_norm = 0.
        self.ema_var = 0.
        self.beta_cumprod = 1.
        self.gradient_noise_scale = float('nan')

    def state_dict(self):
        """Returns the state of the object as a :class:`dict`."""
        return dict(self.__dict__.items())

    def load_state_dict(self, state_dict):
        """Loads the object's state.

        Args:
            state_dict (dict): object state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def update(self, sq_norm_small_batch, sq_norm_large_batch, n_small_batch, n_large_batch):
        """Updates the state with a new batch's gradient statistics, and returns the
        current gradient noise scale.

        Args:
            sq_norm_small_batch (float): The mean of the squared 2-norms of microbatch or
                per sample gradients.
            sq_norm_large_batch (float): The squared 2-norm of the mean of the microbatch or
                per sample gradients.
            n_small_batch (int): The batch size of the individual microbatch or per sample
                gradients (1 if per sample).
            n_large_batch (int): The total batch size of the mean of the microbatch or
                per sample gradients.
        """
        # Unbiased estimators of the squared gradient mean |G|^2 and the gradient
        # variance sum S from the paper's appendix:
        #   |G|^2 ~ (B_big * |G_big|^2 - B_small * |G_small|^2) / (B_big - B_small)
        #   S     ~ (|G_small|^2 - |G_big|^2) / (1 / B_small - 1 / B_big)
        est_sq_norm = (n_large_batch * sq_norm_large_batch - n_small_batch * sq_norm_small_batch) / (n_large_batch - n_small_batch)
        est_var = (sq_norm_small_batch - sq_norm_large_batch) / (1 / n_small_batch - 1 / n_large_batch)
        self.ema_sq_norm = self.beta * self.ema_sq_norm + (1 - self.beta) * est_sq_norm
        self.ema_var = self.beta * self.ema_var + (1 - self.beta) * est_var
        self.beta_cumprod *= self.beta
        # Simple noise scale B_simple = S / |G|^2, clamped by eps for stability.
        self.gradient_noise_scale = max(self.ema_var, self.eps) / max(self.ema_sq_norm, self.eps)
        return self.gradient_noise_scale

    def get_gns(self):
        """Returns the current gradient noise scale."""
        return self.gradient_noise_scale

    def get_stats(self):
        """Returns the current (debiased) estimates of the squared mean gradient
        and gradient variance."""
        return self.ema_sq_norm / (1 - self.beta_cumprod), self.ema_var / (1 - self.beta_cumprod)
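

# Illustrative end-to-end sketch, not part of the original module: feeding the
# per-step statistics from DDPGradientStatsHook into GradientNoiseScale inside a
# DDP training loop. The model, optimizer, loss_fn, loader, and batch size
# arguments are assumptions for the example; here each rank's per-step batch is
# the "small" batch and the global batch summed over ranks is the "large" batch.
def _example_training_loop(ddp_model, hook, opt, loss_fn, loader, local_batch_size, world_size):
    gns = GradientNoiseScale()
    for inputs, targets in loader:
        opt.zero_grad()
        loss = loss_fn(ddp_model(inputs), targets)
        loss.backward()  # the DDP comm hook fires once per gradient bucket here
        sq_norm_small, sq_norm_large = hook.get_stats()
        gns.update(sq_norm_small, sq_norm_large,
                   local_batch_size, local_batch_size * world_size)
        opt.step()
    return gns.get_gns()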