import collections
import importlib
import logging
import os
import time
from collections import OrderedDict
from collections.abc import Sequence
from itertools import repeat

import numpy as np
import torch
import torch.distributed as dist


def print_rank(var_name, var_value, rank=0):
    """Print a named value from the process with the given rank only.

    Assumes the default torch.distributed process group is initialized.
    """
    if dist.get_rank() == rank:
        print(f"[Rank {rank}] {var_name}: {var_value}")


def print_0(*args, **kwargs):
    """Print from rank 0 only (same caveat as ``print_rank``)."""
    if dist.get_rank() == 0:
        print(*args, **kwargs)


def requires_grad(model: torch.nn.Module, flag: bool = True) -> None:
    """
    Set requires_grad flag for all parameters in a model.
    """
    for p in model.parameters():
        p.requires_grad = flag


def format_numel_str(numel: int) -> str:
    B = 1024**3
    M = 1024**2
    K = 1024
    if numel >= B:
        return f"{numel / B:.2f} B"
    elif numel >= M:
        return f"{numel / M:.2f} M"
    elif numel >= K:
        return f"{numel / K:.2f} K"
    else:
        return f"{numel}"


def all_reduce_mean(tensor: torch.Tensor) -> torch.Tensor:
    """All-reduce a tensor across ranks and divide by the world size.

    The reduction happens in place; the input tensor is modified and also
    returned for convenience.
    """
    dist.all_reduce(tensor=tensor, op=dist.ReduceOp.SUM)
    tensor.div_(dist.get_world_size())
    return tensor
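
# Usage sketch (assumes an initialized process group, e.g. via
# dist.init_process_group): averaging a per-rank loss for logging without
# touching the autograd graph.
#   loss_for_log = all_reduce_mean(loss.detach().clone())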


def get_model_numel(model: torch.nn.Module) -> tuple[int, int]:
    """Return (total parameter count, trainable parameter count)."""
    num_params = 0
    num_params_trainable = 0
    for p in model.parameters():
        num_params += p.numel()
        if p.requires_grad:
            num_params_trainable += p.numel()
    return num_params, num_params_trainable


def try_import(name):
    """Try to import a module.

    Args:
        name (str): The module to import, in absolute terms
            (e.g. ``pkg.mod``); relative names are not supported because no
            ``package`` argument is passed to ``importlib.import_module``.

    Returns:
        ModuleType or None: The imported module if the import succeeds,
        otherwise None.
    """
    try:
        return importlib.import_module(name)
    except ImportError:
        return None
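
# Usage sketch: optional dependencies can be imported lazily and checked
# against None before use.
#   yaml = try_import("yaml")
#   if yaml is not None:
#       cfg = yaml.safe_load(open("config.yaml"))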


def transpose(x):
    """Transpose a list of lists, i.e. turn rows into columns.

    Args:
        x (list[list]): The nested list to transpose.
    """
    return list(map(list, zip(*x)))


def get_timestamp():
    return time.strftime("%Y%m%d-%H%M%S", time.localtime())


def format_time(seconds):
    """Format a duration in seconds as a short string such as ``1h1m``.

    At most the two most significant non-zero units are kept.
    """
    days = int(seconds / 3600 / 24)
    seconds = seconds - days * 3600 * 24
    hours = int(seconds / 3600)
    seconds = seconds - hours * 3600
    minutes = int(seconds / 60)
    seconds = seconds - minutes * 60
    secondsf = int(seconds)
    seconds = seconds - secondsf
    millis = int(seconds * 1000)

    f = ""
    i = 1  # one more than the number of units emitted so far
    if days > 0:
        f += str(days) + "D"
        i += 1
    if hours > 0 and i <= 2:
        f += str(hours) + "h"
        i += 1
    if minutes > 0 and i <= 2:
        f += str(minutes) + "m"
        i += 1
    if secondsf > 0 and i <= 2:
        f += str(secondsf) + "s"
        i += 1
    if millis > 0 and i <= 2:
        f += str(millis) + "ms"
        i += 1
    if f == "":
        f = "0ms"
    return f
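
# Example (illustrative only): 3661.5 seconds is 1 hour, 1 minute and 1.5
# seconds; only the two leading non-zero units survive.
#   format_time(3661.5)  ->  "1h1m"
#   format_time(0.25)    ->  "250ms"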


def to_tensor(data):
    """Convert objects of various python types to :obj:`torch.Tensor`.

    Supported types are: :class:`numpy.ndarray`, :class:`torch.Tensor`,
    :class:`Sequence`, :class:`int` and :class:`float`.

    Args:
        data (torch.Tensor | numpy.ndarray | Sequence | int | float): Data to
            be converted.
    """
    if isinstance(data, torch.Tensor):
        return data
    elif isinstance(data, np.ndarray):
        return torch.from_numpy(data)
    elif isinstance(data, Sequence) and not isinstance(data, str):
        return torch.tensor(data)
    elif isinstance(data, int):
        return torch.LongTensor([data])
    elif isinstance(data, float):
        return torch.FloatTensor([data])
    else:
        raise TypeError(f"type {type(data)} cannot be converted to tensor.")


def to_ndarray(data):
    if isinstance(data, torch.Tensor):
        return data.numpy()
    elif isinstance(data, np.ndarray):
        return data
    elif isinstance(data, Sequence) and not isinstance(data, str):
        return np.array(data)
    elif isinstance(data, int):
        # np.array, not np.ndarray: the ndarray constructor would interpret
        # [data] as a shape rather than as the contents.
        return np.array([data], dtype=int)
    elif isinstance(data, float):
        return np.array([data], dtype=float)
    else:
        raise TypeError(f"type {type(data)} cannot be converted to ndarray.")


def to_torch_dtype(dtype):
    if isinstance(dtype, torch.dtype):
        return dtype
    elif isinstance(dtype, str):
        dtype_mapping = {
            "float64": torch.float64,
            "float32": torch.float32,
            "float16": torch.float16,
            "fp32": torch.float32,
            "fp16": torch.float16,
            "half": torch.float16,
            "bf16": torch.bfloat16,
        }
        if dtype not in dtype_mapping:
            raise ValueError(f"unknown dtype string: {dtype!r}")
        return dtype_mapping[dtype]
    else:
        raise ValueError(f"cannot convert {type(dtype)} to a torch dtype.")
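
# Example (illustrative only):
#   to_torch_dtype("bf16")         ->  torch.bfloat16
#   to_torch_dtype(torch.float32)  ->  torch.float32   (passed through)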


def count_params(model):
    """Count trainable parameters only; see ``get_model_numel`` for both."""
    return sum(p.numel() for p in model.parameters() if p.requires_grad)


def _ntuple(n):
    def parse(x):
        # Pass non-string iterables through; repeat scalars n times.
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return x
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)
to_3tuple = _ntuple(3)
to_4tuple = _ntuple(4)
to_ntuple = _ntuple
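
# Example (illustrative only): handy for normalizing kernel-size style
# arguments.
#   to_2tuple(3)       ->  (3, 3)
#   to_2tuple((1, 2))  ->  (1, 2)   (iterables pass through unchanged)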


def convert_SyncBN_to_BN2d(model_cfg):
    """Recursively rewrite ``norm_cfg`` entries from SyncBN to BN2d."""
    for k in model_cfg:
        v = model_cfg[k]
        # .get avoids a KeyError when a norm_cfg has no "type" key.
        if k == "norm_cfg" and v.get("type") == "SyncBN":
            v["type"] = "BN2d"
        elif isinstance(v, dict):
            convert_SyncBN_to_BN2d(v)


def get_topk(x, dim=4, k=5):
    """Return the k rows of ``x`` with the largest values in column ``dim``."""
    x = to_tensor(x)
    inds = x[..., dim].topk(k)[1]
    return x[inds]
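
# Example (illustrative only): with an (N, 5) array of boxes whose last
# column is a confidence score, this keeps the 5 most confident rows.
#   boxes = np.random.rand(100, 5)
#   top = get_topk(boxes, dim=4, k=5)   # shape (5, 5)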


def param_sigmoid(x, alpha):
    return 1 / (1 + (-alpha * x).exp())


def inverse_param_sigmoid(x, alpha, eps=1e-5):
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2) / alpha


def inverse_sigmoid(x, eps=1e-5):
    """Inverse function of sigmoid.

    Args:
        x (Tensor): The tensor to invert.
        eps (float): Small value that avoids numerical overflow.
            Defaults to 1e-5.

    Returns:
        Tensor: The inverse-sigmoid of ``x``, with the same shape as the
        input.
    """
    x = x.clamp(min=0, max=1)
    x1 = x.clamp(min=eps)
    x2 = (1 - x).clamp(min=eps)
    return torch.log(x1 / x2)
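
# Round-trip check (illustrative only): away from the clamped edges,
# inverse_sigmoid undoes torch.sigmoid.
#   t = torch.tensor([-2.0, 0.0, 2.0])
#   inverse_sigmoid(torch.sigmoid(t))  ->  approximately t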


def count_columns(df, columns):
    """For each column, map each distinct value to its (count, fraction).

    Args:
        df (pandas.DataFrame): The frame to summarize.
        columns (list[str]): Column names to count.
    """
    cnt_dict = OrderedDict()
    num_samples = len(df)

    for col in columns:
        d_i = df[col].value_counts().to_dict()
        for k in d_i:
            d_i[k] = (d_i[k], d_i[k] / num_samples)
        cnt_dict[col] = d_i

    return cnt_dict
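
# Example (illustrative only, assumes pandas; fractions rounded here):
#   df = pd.DataFrame({"split": ["train", "train", "val"]})
#   count_columns(df, ["split"])
#   ->  {"split": {"train": (2, 0.67), "val": (1, 0.33)}}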


def build_logger(work_dir, cfgname):
    log_file = cfgname + ".log"
    log_path = os.path.join(work_dir, log_file)

    logger = logging.getLogger(cfgname)
    logger.setLevel(logging.INFO)

    formatter = logging.Formatter("%(asctime)s: %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    # Guard against duplicate handlers: logging.getLogger returns the same
    # logger on repeated calls with the same name, so adding handlers
    # unconditionally would duplicate every log line.
    if not logger.handlers:
        handler1 = logging.FileHandler(log_path)
        handler1.setFormatter(formatter)

        handler2 = logging.StreamHandler()
        handler2.setFormatter(formatter)

        logger.addHandler(handler1)
        logger.addHandler(handler2)

    logger.propagate = False

    return logger