File size: 1,409 Bytes
8741abe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 |
from omegaconf import OmegaConf
import torch
from typing import (
Any,
Callable,
Dict,
Iterable,
List,
NamedTuple,
NewType,
Optional,
Sized,
Tuple,
Type,
TypeVar,
Union,
)
try:
from typing import Literal
except ImportError:
from typing_extensions import Literal
# Tensor dtype
# for jaxtyping usage, see https://github.com/google/jaxtyping/blob/main/API.md
from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt
# Config type
from omegaconf import DictConfig
# PyTorch Tensor type
from torch import Tensor
# Runtime type checking decorator
from typeguard import typechecked as typechecker
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` to all ranks and return it.

    If torch.distributed is unavailable, or no process group has been
    initialised, nothing is communicated and ``tensor`` is returned as-is, so
    single-process code can call this unconditionally.

    Args:
        tensor: The tensor to send (on rank ``src``) or to receive into.
        src: Rank that owns the data to broadcast.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    if _distributed_available():
        # The collective writes into ``tensor`` itself, so the input object is
        # what gets handed back below.
        torch.distributed.broadcast(tensor, src=src)
    return tensor
def _distributed_available():
return torch.distributed.is_available() and torch.distributed.is_initialized()
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Build a structured OmegaConf config by instantiating ``fields`` from ``cfg``.

    Args:
        fields: A structured-config type (presumably a ``@dataclass``) whose
            constructor accepts the entries of ``cfg`` as keyword arguments.
        cfg: Mapping of field name to value. ``None`` is treated as an empty
            mapping, so ``fields`` is constructed from its own defaults.

    Returns:
        ``OmegaConf.structured(fields(**entries))``, where ``entries`` is ``cfg``
        minus any ``'--local-rank'`` key.

    The caller's ``cfg`` is never modified; the ``'--local-rank'`` key is
    dropped from a private copy only.
    """
    # Work on a shallow copy: deleting from the caller's mapping was a hidden
    # side effect, and deletion raises on a read-only DictConfig.
    # dict(cfg) reads keys/values the same way ``**cfg`` does.
    entries = {} if cfg is None else dict(cfg)
    # torch.distributed.launch (PyTorch >= 2.0) passes ``--local-rank=<rank>`` to
    # the training script. In multi-node runs that argument ends up as a key in
    # ``cfg`` (presumably via CLI-to-config merging upstream -- confirm). It is
    # not a valid keyword for ``fields``, so drop it.
    entries.pop("--local-rank", None)
    return OmegaConf.structured(fields(**entries))