Spaces:
Runtime error
Runtime error
File size: 2,614 Bytes
3e06e1c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
# Copyright (c) OpenMMLab. All rights reserved.
import math
from typing import Optional
import torch
import torch.nn as nn
from mmengine.model import ExponentialMovingAverage
from torch import Tensor
from mmdet.registry import MODELS
@MODELS.register_module()
class ExpMomentumEMA(ExponentialMovingAverage):
    """Exponential moving average (EMA) with exponential momentum strategy,
    which is used in YOLOX.

    Args:
        model (nn.Module): The model to be averaged.
        momentum (float): The momentum used for updating ema parameter.
            Ema's parameters are updated with the formula:
            `averaged_param = (1-momentum) * averaged_param + momentum *
            source_param`. Defaults to 0.0002.
        gamma (int): Controls how fast the effective momentum decays. A
            larger momentum is used early in training and is gradually
            annealed towards ``momentum`` so that the ema model is updated
            smoothly. The effective momentum is calculated as
            `(1 - momentum) * exp(-(1 + steps) / gamma) + momentum`.
            Must be greater than 0. Defaults to 2000.
        interval (int): Interval between two updates. Defaults to 1.
        device (torch.device, optional): If provided, the averaged model will
            be stored on the :attr:`device`. Defaults to None.
        update_buffers (bool): if True, it will compute running averages for
            both the parameters and the buffers of the model. Defaults to
            False.
    """

    def __init__(self,
                 model: nn.Module,
                 momentum: float = 0.0002,
                 gamma: int = 2000,
                 interval: int = 1,
                 device: Optional[torch.device] = None,
                 update_buffers: bool = False) -> None:
        # Validate gamma before the base-class constructor runs so an invalid
        # value is rejected before any work is done with `model`. gamma is a
        # divisor in `avg_func`, so it must be strictly positive.
        assert gamma > 0, f'gamma must be greater than 0, but got {gamma}'
        super().__init__(
            model=model,
            momentum=momentum,
            interval=interval,
            device=device,
            update_buffers=update_buffers)
        self.gamma = gamma

    def avg_func(self, averaged_param: Tensor, source_param: Tensor,
                 steps: int) -> None:
        """Compute the moving average of the parameters using the exponential
        momentum strategy.

        The update is applied in place to ``averaged_param``.

        Args:
            averaged_param (Tensor): The averaged parameters. Modified in
                place.
            source_param (Tensor): The source parameters.
            steps (int): The number of times the parameters have been
                updated.
        """
        # Effective momentum: (1 - m) * exp(-(1 + steps) / gamma) + m.
        # For large gamma it starts close to 1 (the EMA closely tracks the
        # source early on) and decays towards `self.momentum` as `steps`
        # grows, because the exponential term goes to 0.
        momentum = (1 - self.momentum) * math.exp(
            -float(1 + steps) / self.gamma) + self.momentum
        # In place: averaged = (1 - momentum) * averaged + momentum * source.
        averaged_param.mul_(1 - momentum).add_(source_param, alpha=momentum)
|