# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

from typing import Any

import numpy as np
import torch
import torch.nn as nn

from src.efficientvit.apps.trainer.run_config import Scheduler
from src.efficientvit.models.nn.ops import IdentityLayer, ResidualBlock
from src.efficientvit.models.utils import build_kwargs_from_config

__all__ = ["apply_drop_func"]


def apply_drop_func(network: nn.Module, drop_config: dict[str, Any] | None) -> None:
    """Apply the drop regularizer named in `drop_config` to `network` in-place."""
    if drop_config is None:
        return

    drop_lookup_table = {
        "droppath": apply_droppath,
    }

    drop_func = drop_lookup_table[drop_config["name"]]
    drop_kwargs = build_kwargs_from_config(drop_config, drop_func)

    drop_func(network, **drop_kwargs)

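# A minimal usage sketch (hypothetical config values; the "name" key selects
# the regularizer and the remaining keys are matched to apply_droppath's
# signature by build_kwargs_from_config):
#
#   drop_config = {"name": "droppath", "drop_prob": 0.1, "linear_decay": True}
#   apply_drop_func(model, drop_config)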

def apply_droppath(
    network: nn.Module,
    drop_prob: float,
    linear_decay: bool = True,
    scheduled: bool = True,
    skip: int = 0,
) -> None:
    """Replace eligible residual blocks of `network` with drop-path variants."""
    # Collect residual blocks whose shortcut is a pure identity; only these
    # can safely be replaced by a stochastic-depth variant.
    all_valid_blocks = []
    for m in network.modules():
        for name, sub_module in m.named_children():
            if isinstance(sub_module, ResidualBlock) and isinstance(
                sub_module.shortcut, IdentityLayer
            ):
                all_valid_blocks.append((m, name, sub_module))
    all_valid_blocks = all_valid_blocks[skip:]
    for i, (m, name, sub_module) in enumerate(all_valid_blocks):
        # "Linear decay" refers to the survival probability decaying linearly
        # with depth: the drop probability grows from drop_prob / n for the
        # first eligible block to drop_prob for the last.
        prob = (
            drop_prob * (i + 1) / len(all_valid_blocks) if linear_decay else drop_prob
        )
        new_module = DropPathResidualBlock(
            sub_module.main,
            sub_module.shortcut,
            sub_module.post_act,
            sub_module.pre_norm,
            prob,
            scheduled,
        )
        m._modules[name] = new_module

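# Worked example of the linear schedule (assumed numbers): with drop_prob=0.1
# and 4 eligible blocks, the per-block drop probabilities are
# 0.025, 0.05, 0.075, 0.1, so deeper blocks are dropped more often, as in
# "Deep Networks with Stochastic Depth" (Huang et al., ECCV 2016).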

class DropPathResidualBlock(ResidualBlock):
    def __init__(
        self,
        main: nn.Module,
        shortcut: nn.Module | None,
        post_act=None,
        pre_norm: nn.Module | None = None,
        # drop-path-specific arguments
        drop_prob: float = 0,
        scheduled: bool = True,
    ):
        super().__init__(main, shortcut, post_act, pre_norm)

        self.drop_prob = drop_prob
        self.scheduled = scheduled

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if (
            not self.training
            or self.drop_prob == 0
            or not isinstance(self.shortcut, IdentityLayer)
        ):
            # No stochastic depth at inference, when dropping is disabled, or
            # when the shortcut is not a pure identity.
            return ResidualBlock.forward(self, x)
        else:
            drop_prob = self.drop_prob
            if self.scheduled:
                # Ramp the drop probability up with training progress.
                drop_prob *= np.clip(Scheduler.PROGRESS, 0, 1)
            keep_prob = 1 - drop_prob

            # Per-sample Bernoulli mask, broadcast over all non-batch dims.
            shape = (x.shape[0],) + (1,) * (x.ndim - 1)
            random_tensor = keep_prob + torch.rand(
                shape, dtype=x.dtype, device=x.device
            )
            random_tensor.floor_()  # binarize: 1 with prob keep_prob, else 0

            # Scale surviving main-branch outputs by 1 / keep_prob so the
            # expected output matches the non-dropped forward pass.
            res = self.forward_main(x) / keep_prob * random_tensor + self.shortcut(x)
            if self.post_act:
                res = self.post_act(res)
            return res
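

if __name__ == "__main__":
    # A minimal smoke test (a sketch, not part of the original module): a plain
    # nn.Conv2d stands in for the repo's conv blocks, and scheduled=False avoids
    # depending on Scheduler.PROGRESS being set by a trainer.
    block = DropPathResidualBlock(
        main=nn.Conv2d(8, 8, kernel_size=3, padding=1),
        shortcut=IdentityLayer(),
        drop_prob=0.5,
        scheduled=False,
    )
    x = torch.randn(4, 8, 16, 16)
    block.train()
    y_train = block(x)  # per-sample: main branch kept (rescaled) or dropped
    block.eval()
    y_eval = block(x)  # deterministic residual forward
    print(y_train.shape, y_eval.shape)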