KyanChen's picture
Upload 787 files
3e06e1c
raw
history blame
2.61 kB
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage
from torch import Tensor
from mmdet.registry import MODELS
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
"""Exponential moving average (EMA) with exponential momentum strategy,
which is used in YOLOX.
Args:
model (nn.Module): The model to be averaged.
momentum (float): The momentum used for updating ema parameter.
Ema's parameter are updated with the formula:
`averaged_param = (1-momentum) * averaged_param + momentum *
source_param`. Defaults to 0.0002.
gamma (int): Use a larger momentum early in training and gradually
annealing to a smaller value to update the ema model smoothly. The
momentum is calculated as
`(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
Defaults to 2000.
interval (int): Interval between two updates. Defaults to 1.
device (torch.device, optional): If provided, the averaged model will
be stored on the :attr:`device`. Defaults to None.
update_buffers (bool): if True, it will compute running averages for
both the parameters and the buffers of the model. Defaults to
False.
"""
def __init__(self,
model: nn.Module,
momentum: float = 0.0002,
gamma: int = 2000,
interval=1,
device: Optional[torch.device] = None,
update_buffers: bool = False) -> None:
super().__init__(
model=model,
momentum=momentum,
interval=interval,
device=device,
update_buffers=update_buffers)
assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
self.gamma = gamma
def avg_func(self, averaged_param: Tensor, source_param: Tensor,
steps: int) -> None:
"""Compute the moving average of the parameters using the exponential
momentum strategy.
Args:
averaged_param (Tensor): The averaged parameters.
source_param (Tensor): The source parameters.
steps (int): The number of times the parameters have been
updated.
"""
momentum = (1 - self.momentum) * math.exp(
-float(1 + steps) / self.gamma) + self.momentum
averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)