from copy import deepcopy
from typing import Optional

from torch import nn
import torch as th
import pytorch_lightning as pl


from .models import CLIPClassifier
from .convnext_meta import build_covnext


class EMA(nn.Module):
    """Model Exponential Moving Average V2 from timm"""

    def __init__(self, model: nn.Module, decay: float = 0.9999):
        super(EMA, self).__init__()
        # make a copy of the model for accumulating moving average of weights
        self.module = deepcopy(model)
        self.module.eval()
        self.decay = decay

    def _update(self, model: nn.Module, update_fn):
        with th.no_grad():
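            # Walk both state dicts in parallel; the ordering matches because
            # self.module is a deepcopy of the tracked model.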
            for ema_v, model_v in zip(
                self.module.state_dict().values(), model.state_dict().values()
            ):
                ema_v.copy_(update_fn(ema_v, model_v))

    def update(self, model):
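        # Standard EMA step: ema = decay * ema + (1 - decay) * current weights.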
        self._update(
            model, update_fn=lambda e, m: self.decay * e + (1.0 - self.decay) * m
        )

    def set(self, model):
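        # Overwrite the EMA weights with the model's current weights.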
        self._update(model, update_fn=lambda e, m: m)


class MosquitoClassifier(pl.LightningModule):
    def __init__(
        self,
        n_classes: int = 6,
        model_name: str = "ViT-L-14",
        dataset: Optional[str] = None,
        head_version: int = 0,
        use_ema: bool = False,
    ):
        super().__init__()
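        # "imagenet" selects the ConvNeXt backbone; any other value falls back
        # to the CLIP-based classifier head.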
        if dataset == "imagenet":
            self.cls = build_covnext(model_name, n_classes)
        else:
            self.cls = CLIPClassifier(n_classes, model_name, dataset, head_version)

        self.use_ema = use_ema
        if use_ema:
            self.ema = EMA(self.cls, decay=0.995)

    def forward(self, x: th.Tensor) -> th.Tensor:
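        # During evaluation, use the smoothed EMA weights when they are tracked.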
        if self.use_ema and not self.training:
            print("Using EMA...")
            return nn.Softmax(dim=1)(self.ema.module(x))
        return nn.Softmax(dim=1)(self.cls(x))
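

# --- Usage sketch (illustrative; the training wiring below is an assumption,
# not something shown in this file) ---
#
# The EMA copy only tracks the classifier if ``update`` is called after each
# optimizer step, e.g. from a LightningModule hook such as:
#
#     def on_train_batch_end(self, outputs, batch, batch_idx):
#         if self.use_ema:
#             self.ema.update(self.cls)
#
# The EMA class itself can also be exercised standalone with a toy model:
#
#     model = nn.Linear(4, 2)
#     ema = EMA(model, decay=0.99)
#     # ... run an optimizer step on ``model``, then:
#     ema.update(model)        # move the EMA weights toward the new weights
#     with th.no_grad():
#         smoothed = ema.module(th.randn(1, 4))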