from copy import deepcopy
from typing import Optional

import torch as th
from torch import nn
import pytorch_lightning as pl

from .models import CLIPClassifier
from .convnext_meta import build_covnext


class EMA(nn.Module):
    """Model Exponential Moving Average V2 from timm."""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        super().__init__()
        # Make a copy of the model for accumulating the moving average of weights.
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model: nn.Module, update_fn):
        with th.no_grad():
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model: nn.Module):
        # Blend the live weights into the EMA copy:
        # ema = decay * ema + (1 - decay) * model.
        self._update(
            model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
        )

    def set(self, model: nn.Module):
        # Overwrite the EMA copy with the live weights.
        self._update(model, update_fn=lambda e, m: m)


class MosquitoClassifier(pl.LightningModule):
    def __init__(
        self,
        n_classes: int = 6,
        model_name: str = "ViT-L-14",
        dataset: Optional[str] = None,
        head_version: int = 0,
        use_ema: bool = False,
    ):
        super().__init__()

        # ImageNet-pretrained backbones go through the ConvNeXt builder;
        # everything else uses the CLIP-based classifier head.
        if dataset == "imagenet":
            self.cls = build_covnext(model_name, n_classes)
        else:
            self.cls = CLIPClassifier(n_classes, model_name, dataset, head_version)

        self.use_ema = use_ema
        if use_ema:
            self.ema = EMA(self.cls, decay=0.995)

    def forward(self, x: th.Tensor) -> th.Tensor:
        # At inference time, prefer the EMA weights when they are enabled.
        if self.use_ema and not self.training:
            return th.softmax(self.ema.module(x), dim=1)
        return th.softmax(self.cls(x), dim=1)
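
# -----------------------------------------------------------------------------
# The classifier above creates an EMA copy of the weights, but as written
# nothing refreshes that copy during training. Below is a minimal sketch of one
# way to wire the update in, using Lightning's standard `on_train_batch_end`
# hook; the subclass name and the choice to update once per batch are
# assumptions for illustration, not part of the original module.


class MosquitoClassifierWithEMAUpdate(MosquitoClassifier):
    def on_train_batch_end(self, outputs, batch, batch_idx):
        # Fold the freshly optimized weights into the EMA copy after every
        # optimizer step.
        if self.use_ema:
            self.ema.update(self.cls)


if __name__ == "__main__":
    # Hypothetical smoke test: push a random batch through the classifier and
    # check that forward() returns a valid probability distribution. The
    # 224x224 input resolution is an assumption based on the ViT-L-14 default.
    model = MosquitoClassifier(n_classes=6, model_name="ViT-L-14", use_ema=True)
    model.eval()
    x = th.randn(2, 3, 224, 224)
    with th.no_grad():
        probs = model(x)
    assert probs.shape == (2, 6)
    assert th.allclose(probs.sum(dim=1), th.ones(2), atol=1e-5)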