File size: 4,774 Bytes
ad5354d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
# EfficientViT: Multi-Scale Linear Attention for High-Resolution Dense Prediction
# Han Cai, Junyan Li, Muyan Hu, Chuang Gan, Song Han
# International Conference on Computer Vision (ICCV), 2023

import os
import time
from copy import deepcopy

import torch.backends.cudnn
import torch.distributed
import torch.nn as nn

from src.efficientvit.apps.data_provider import DataProvider
from src.efficientvit.apps.trainer.run_config import RunConfig
from src.efficientvit.apps.utils import (dist_init, dump_config,
                                     get_dist_local_rank, get_dist_rank,
                                     get_dist_size, init_modules, is_master,
                                     load_config, partial_update_config,
                                     zero_last_gamma)
from src.efficientvit.models.utils import (build_kwargs_from_config,
                                       load_state_dict_from_file)

# Public API of this module: the names exported by a star-import.
__all__ = [
    "save_exp_config",
    "setup_dist_env",
    "setup_seed",
    "setup_exp_config",
    "setup_data_provider",
    "setup_run_config",
    "init_model",
]


def save_exp_config(exp_config: dict, path: str, name="config.yaml") -> None:
    """Dump the experiment config to ``path/name`` when ``is_master()`` is true.

    Processes for which ``is_master()`` is false do nothing.

    Args:
        exp_config: Experiment configuration to write out.
        path: Directory that receives the dumped file.
        name: File name joined onto ``path``.
    """
    if is_master():
        destination = os.path.join(path, name)
        dump_config(exp_config, destination)


def setup_dist_env(gpu: str or None = None) -> None:
    """Prepare the process for distributed CUDA execution.

    Steps, in order: optionally export ``CUDA_VISIBLE_DEVICES``, call
    ``dist_init()`` if ``torch.distributed`` is not initialized yet, enable
    cuDNN benchmark mode, and select the CUDA device whose index is
    ``get_dist_local_rank()``.

    Args:
        gpu: Value written to ``CUDA_VISIBLE_DEVICES``; ``None`` leaves the
            environment variable untouched.
    """
    if gpu is not None:
        os.environ["CUDA_VISIBLE_DEVICES"] = gpu

    already_initialized = torch.distributed.is_initialized()
    if not already_initialized:
        dist_init()

    torch.backends.cudnn.benchmark = True

    local_rank = get_dist_local_rank()
    torch.cuda.set_device(local_rank)


def setup_seed(manual_seed: int, resume: bool) -> None:
    """Seed the torch CPU and CUDA random generators for this process.

    The effective seed is ``get_dist_rank()`` plus a base seed, so each rank
    gets its own offset.

    Args:
        manual_seed: Base seed used when ``resume`` is false.
        resume: When true, ``manual_seed`` is ignored and the current Unix
            time (truncated to whole seconds) is used as the base seed.
    """
    base_seed = int(time.time()) if resume else manual_seed
    rank_seed = get_dist_rank() + base_seed
    torch.manual_seed(rank_seed)
    torch.cuda.manual_seed_all(rank_seed)


def setup_exp_config(
    config_path: str, recursive=True, opt_args: dict or None = None
) -> dict:
    """Build the experiment config from a config file and its ``default`` ancestors.

    With ``recursive`` enabled, every ancestor directory of ``config_path`` is
    checked for a file named ``default<ext>`` (same extension as
    ``config_path``). The files found are applied from the outermost directory
    inwards, with ``config_path`` itself applied last, so more specific files
    take part in the update after more general ones.

    Args:
        config_path: Path to the experiment config file.
        recursive: Whether to look for ``default<ext>`` files in ancestor
            directories.
        opt_args: Optional overrides applied after all files are merged.

    Returns:
        The merged configuration.

    Raises:
        ValueError: If ``config_path`` is not an existing file; the path is
            the exception argument.
    """
    if not os.path.isfile(config_path):
        raise ValueError(config_path)

    # Collected most-specific first, then flipped so the outermost default
    # becomes the base that everything else is layered onto.
    config_chain = [config_path]
    if recursive:
        default_name = "default" + os.path.splitext(config_path)[1]
        current = config_path
        while True:
            parent = os.path.dirname(current)
            if parent == current:
                # dirname() has reached its fixed point (root or "").
                break
            current = parent
            candidate = os.path.join(current, default_name)
            if os.path.isfile(candidate):
                config_chain.append(candidate)
        config_chain.reverse()

    base_path, *override_paths = config_chain
    exp_config = deepcopy(load_config(base_path))
    for override_path in override_paths:
        partial_update_config(exp_config, load_config(override_path))

    # Explicit overrides are applied last.
    if opt_args is not None:
        partial_update_config(exp_config, opt_args)

    return exp_config


def setup_data_provider(
    exp_config: dict,
    data_provider_classes: list[type[DataProvider]],
    is_distributed: bool = True,
) -> DataProvider:
    """Instantiate the data provider selected by ``exp_config["data_provider"]``.

    The ``data_provider`` section of ``exp_config`` is updated in place with
    ``num_replicas``, ``rank``, ``test_batch_size``, ``batch_size`` and
    ``train_batch_size`` before the provider is constructed.

    Args:
        exp_config: Experiment config containing a ``data_provider`` section
            with at least ``dataset`` and ``base_batch_size``.
        data_provider_classes: Candidate provider classes, matched against
            the ``dataset`` entry through their ``name`` attribute.
        is_distributed: When false, ``num_replicas`` and ``rank`` are set to
            ``None`` instead of the distributed world size and rank.

    Returns:
        The constructed data provider instance.

    Raises:
        KeyError: If a required config entry is missing or no class in
            ``data_provider_classes`` has the requested ``name``.
    """
    dp_config = exp_config["data_provider"]

    if is_distributed:
        dp_config["num_replicas"] = get_dist_size()
        dp_config["rank"] = get_dist_rank()
    else:
        dp_config["num_replicas"] = None
        dp_config["rank"] = None

    # A missing or falsy test batch size falls back to twice the base size.
    test_batch_size = dp_config.get("test_batch_size", None)
    if not test_batch_size:
        test_batch_size = dp_config["base_batch_size"] * 2
    dp_config["test_batch_size"] = test_batch_size

    base_batch_size = dp_config["base_batch_size"]
    dp_config["batch_size"] = base_batch_size
    dp_config["train_batch_size"] = base_batch_size

    # Name -> class registry; a later class with a duplicate name wins.
    registry = {}
    for provider_cls in data_provider_classes:
        registry[provider_cls.name] = provider_cls
    selected_cls = registry[dp_config["dataset"]]

    init_kwargs = build_kwargs_from_config(dp_config, selected_cls)
    return selected_cls(**init_kwargs)


def setup_run_config(exp_config: dict, run_config_cls: type[RunConfig]) -> RunConfig:
    """Create the run config, scaling the learning rate by ``get_dist_size()``.

    ``exp_config["run_config"]["init_lr"]`` is written in place as
    ``base_lr * get_dist_size()`` before the section is unpacked into
    ``run_config_cls``.

    Args:
        exp_config: Experiment config with a ``run_config`` section that
            contains ``base_lr``.
        run_config_cls: Class instantiated with the ``run_config`` section as
            keyword arguments.

    Returns:
        The constructed run config instance.
    """
    run_section = exp_config["run_config"]
    run_section["init_lr"] = run_section["base_lr"] * get_dist_size()
    return run_config_cls(**run_section)


def init_model(
    network: nn.Module,
    init_from: str or None = None,
    backbone_init_from: str or None = None,
    rand_init="trunc_normal",
    last_gamma=None,
) -> None:
    """Initialize ``network`` and optionally load weights from a checkpoint file.

    The network is always passed through ``init_modules`` first. Weights are
    then loaded from the first usable source: ``init_from`` for the whole
    network, otherwise ``backbone_init_from`` for ``network.backbone``. A
    path that is ``None`` or does not point to an existing file is skipped
    without raising. One line describing the outcome is printed.

    Args:
        network: Model to initialize in place.
        init_from: Checkpoint path for the full network.
        backbone_init_from: Checkpoint path for ``network.backbone`` only;
            consulted only when ``init_from`` is not usable.
        rand_init: Initialization type forwarded to ``init_modules``.
        last_gamma: When not ``None``, forwarded to ``zero_last_gamma``
            (original note: zero gamma of last bn in each block).
    """

    def _is_existing_file(candidate) -> bool:
        return candidate is not None and os.path.isfile(candidate)

    # Initialization always runs, even if a checkpoint is loaded afterwards.
    init_modules(network, init_type=rand_init)
    if last_gamma is not None:
        zero_last_gamma(network, last_gamma)

    # A full-network checkpoint takes precedence over a backbone-only one.
    if _is_existing_file(init_from):
        network.load_state_dict(load_state_dict_from_file(init_from))
        print(f"Loaded init from {init_from}")
        return

    if _is_existing_file(backbone_init_from):
        network.backbone.load_state_dict(load_state_dict_from_file(backbone_init_from))
        print(f"Loaded backbone init from {backbone_init_from}")
        return

    print(f"Random init ({rand_init}) with last gamma {last_gamma}")