import torch


class EMA(torch.nn.Module):
    """Exponential moving average (EMA) of a model's trainable parameters.

    The wrapped model is used for the forward pass; a set of shadow
    parameters tracks an exponentially decayed average of its weights.
    """

    def __init__(self, model: torch.nn.Module, decay: float = 0.999):
        super().__init__()
        self.model = model
        self.decay = decay
        if hasattr(self.model, "time_geopath"):
            self.time_geopath = self.model.time_geopath

        # Number of EMA updates so far; kept as a buffer so it is saved in
        # and restored from the state dict.
        self.register_buffer("num_updates", torch.tensor(0))

        # Shadow copies of every trainable parameter, refreshed by update_ema().
        self.shadow_params = torch.nn.ParameterList(
            [
                torch.nn.Parameter(p.clone().detach(), requires_grad=False)
                for p in model.parameters()
                if p.requires_grad
            ]
        )
        self.backup_params = []

    def train(self, mode: bool = True):
        if self.training and not mode:
            # Switching to eval: back up the live weights, then load the EMA
            # weights into the model for evaluation.
            self.backup()
            self.copy_to_model()
        elif not self.training and mode:
            # Switching back to train: restore the live (non-EMA) weights.
            self.restore_to_model()
        return super().train(mode)

    def update_ema(self):
        self.num_updates += 1
        num_updates = self.num_updates.item()
        # Warm-up schedule: the effective decay ramps up towards self.decay so
        # the shadow weights are not dominated by the initial random values.
        decay = min(self.decay, (1 + num_updates) / (10 + num_updates))
        with torch.no_grad():
            params = [p for p in self.model.parameters() if p.requires_grad]
            for shadow, param in zip(self.shadow_params, params):
                shadow.sub_((1 - decay) * (shadow - param))

    def forward(self, *args, **kwargs):
        return self.model(*args, **kwargs)

    def copy_to_model(self):
        # Overwrite the model's trainable parameters with the EMA weights.
        params = [p for p in self.model.parameters() if p.requires_grad]
        for shadow, param in zip(self.shadow_params, params):
            param.data.copy_(shadow.data)

    def backup(self):
        # Save the current model parameters so they can be restored after
        # evaluating with the EMA weights.
        if len(self.backup_params) > 0:
            for p, b in zip(self.model.parameters(), self.backup_params):
                b.data.copy_(p.data)
        else:
            self.backup_params = [param.clone() for param in self.model.parameters()]

    def restore_to_model(self):
        # Restore the previously backed-up (non-EMA) parameters.
        for param, backup in zip(self.model.parameters(), self.backup_params):
            param.data.copy_(backup.data)
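

if __name__ == "__main__":
    # Minimal usage sketch (illustrative only): the tiny linear model and the
    # random data below are placeholders, not part of this module. After each
    # optimizer step, update_ema() refreshes the shadow weights; eval() swaps
    # the EMA weights into the model, and train() restores the live weights.
    model = torch.nn.Linear(10, 2)
    ema = EMA(model, decay=0.999)
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    for _ in range(5):
        x = torch.randn(8, 10)
        y = torch.randint(0, 2, (8,))
        optimizer.zero_grad()
        loss = torch.nn.functional.cross_entropy(ema(x), y)
        loss.backward()
        optimizer.step()
        ema.update_ema()  # refresh the shadow weights from the live ones

    ema.eval()   # back up live weights and load the EMA weights
    ema.train()  # restore the live weights to continue training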