Spaces:
Sleeping
Sleeping
File size: 1,821 Bytes
9093750 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from copy import deepcopy
from torch import nn
import torch as th
import pytorch_lightning as pl
from .models import CLIPClassifier
from .convnext_meta import build_covnext
class EMA(nn.Module):
    """Exponential moving average (EMA) of a model's weights.

    Model Exponential Moving Average V2, adapted from timm. Keeps a deep copy
    of ``model`` in ``self.module`` and, on each :meth:`update`, blends the
    live model's state into it::

        ema = decay * ema + (1 - decay) * model

    Args:
        model: The module whose weights are tracked. It is deep-copied, so
            later changes to ``model`` only reach the EMA via ``update``/``set``.
        decay: Weight kept by the existing EMA value on each update; values
            closer to 1.0 average over more steps.
    """

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        super().__init__()
        # Private copy of the model that accumulates the moving average.
        self.module = deepcopy(model)
        self.module.eval()
        # The EMA copy is only ever written in place under no_grad, never
        # trained by backprop; freezing it keeps autograd from tracking it.
        for param in self.module.parameters():
            param.requires_grad_(False)
        self.decay = decay

    def _update(self, model: nn.Module, update_fn):
        """Write ``update_fn(ema_value, model_value)`` into every EMA tensor.

        Only floating-point tensors go through ``update_fn``. Non-float
        entries (e.g. BatchNorm's integer ``num_batches_tracked``) are copied
        verbatim: blending them as floats and casting back would truncate the
        result and freeze the counter.

        Assumes ``model`` has the same architecture as the tracked model, so
        both state dicts list their tensors in the same order.
        """
        with th.no_grad():
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                if ema_v.is_floating_point():
                    ema_v.copy_(update_fn(ema_v, model_v))
                else:
                    ema_v.copy_(model_v)

    def update(self, model: nn.Module):
        """Blend ``model``'s current weights into the running average."""
        self._update(
            model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
        )

    def set(self, model: nn.Module):
        """Overwrite the EMA weights with an exact copy of ``model``'s."""
        self._update(model, update_fn=lambda e, m: m)
class MosquitoClassifier(pl.LightningModule):
    """Lightning wrapper that turns a backbone classifier into class probabilities.

    Args:
        n_classes: Number of output classes, forwarded to the model builder.
        model_name: Backbone identifier forwarded to the model builder.
        dataset: When ``"imagenet"``, the network is built with
            ``build_covnext``; otherwise a ``CLIPClassifier`` is built and
            ``dataset`` is forwarded to it.
        head_version: Forwarded to ``CLIPClassifier`` (unused on the
            ``"imagenet"`` path).
        use_ema: If True, keep an ``EMA`` copy of the network (decay 0.995)
            and run inference through it whenever the module is in eval mode.
    """

    def __init__(
        self,
        n_classes: int = 6,
        model_name: str = "ViT-L-14",
        dataset: str = None,
        head_version: int = 0,
        use_ema: bool = False,
    ):
        super().__init__()
        if dataset == "imagenet":
            self.cls = build_covnext(model_name, n_classes)
        else:
            self.cls = CLIPClassifier(n_classes, model_name, dataset, head_version)
        self.use_ema = use_ema
        # Announce EMA inference once per instance instead of on every batch.
        self._ema_announced = False
        if use_ema:
            self.ema = EMA(self.cls, decay=0.995)

    def forward(self, x: th.Tensor) -> th.Tensor:
        """Return softmax probabilities over dim 1 of the network's output.

        In eval mode with ``use_ema`` enabled the EMA weights are used;
        otherwise the live network ``self.cls`` is used.
        NOTE(review): assumes the network emits logits shaped
        (batch, n_classes) — confirm against CLIPClassifier/build_covnext.
        """
        use_ema_weights = self.use_ema and not self.training
        if use_ema_weights and not self._ema_announced:
            print("Using EMA...")
            self._ema_announced = True
        net = self.ema.module if use_ema_weights else self.cls
        return th.softmax(net(x), dim=1)
|