Spaces:
Sleeping
Sleeping
from copy import deepcopy
from typing import Optional

import pytorch_lightning as pl
import torch as th
from torch import nn

from .convnext_meta import build_covnext
from .models import CLIPClassifier
class EMA(nn.Module):
    """Model Exponential Moving Average V2 from timm.

    Keeps a deep copy of ``model`` in ``self.module`` (switched to eval mode)
    and blends the live model's ``state_dict`` entries into that copy on each
    :meth:`update` call.

    Args:
        model: Module whose weights are tracked; it is deep-copied, not shared.
        decay: Weight given to the existing average on each update.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        super().__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model: nn.Module, update_fn) -> None:
        """Write ``update_fn(ema_value, model_value)`` into every EMA entry.

        Entries are paired positionally between the two ``state_dict``s, so
        ``model`` is expected to have the same layout as the copied module.
        """
        with th.no_grad():
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                if ema_v.is_floating_point() or ema_v.is_complex():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    # Integer/bool buffers (e.g. BatchNorm's
                    # ``num_batches_tracked``) cannot hold a fractional
                    # blend: the result would be truncated on copy and the
                    # entry would never move. Mirror the live value instead.
                    ema_v.copy_(model_v)

    def update(self, model: nn.Module) -> None:
        """Blend ``model``'s current state into the average using ``decay``."""
        self._update(
            model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
        )

    def set(self, model: nn.Module) -> None:
        """Overwrite the averaged state with ``model``'s current state."""
        self._update(model, update_fn=lambda e, m: m)
class MosquitoClassifier(pl.LightningModule):
    """LightningModule wrapping a classifier backbone with optional EMA weights.

    The backbone is built by ``build_covnext`` when ``dataset == "imagenet"``
    and by ``CLIPClassifier`` otherwise. When ``use_ema`` is true, an
    :class:`EMA` copy of the backbone is kept and used for eval-mode
    forward passes.

    Args:
        n_classes: Number of output classes, forwarded to the backbone builder.
        model_name: Backbone identifier, forwarded to the backbone builder.
        dataset: ``"imagenet"`` selects the ``build_covnext`` path; any other
            value (including ``None``) is forwarded to ``CLIPClassifier``.
        head_version: Forwarded to ``CLIPClassifier``; unused on the
            ``"imagenet"`` path.
        use_ema: Whether to keep an EMA copy of the backbone.
    """

    def __init__(
        self,
        n_classes: int = 6,
        model_name: str = "ViT-L-14",
        dataset: Optional[str] = None,
        head_version: int = 0,
        use_ema: bool = False,
    ):
        super().__init__()
        if dataset == "imagenet":
            self.cls = build_covnext(model_name, n_classes)
        else:
            self.cls = CLIPClassifier(n_classes, model_name, dataset, head_version)

        self.use_ema = use_ema
        if use_ema:
            # ``self.ema`` only exists when EMA is enabled; ``forward`` guards
            # every access with ``self.use_ema``.
            self.ema = EMA(self.cls, decay=0.995)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Return the softmax over dim 1 of the backbone output for ``x``.

        The EMA copy is used only when EMA is enabled and the module is in
        eval mode; otherwise the live backbone ``self.cls`` is used.
        """
        if self.use_ema and not self.training:
            # NOTE(review): this prints on every eval-mode forward call.
            print("Using EMA...")
            logits = self.ema.module(x)
        else:
            logits = self.cls(x)
        return th.softmax(logits, dim=1)