| from omegaconf import OmegaConf |
| import torch |
| from typing import ( |
| Any, |
| Callable, |
| Dict, |
| Iterable, |
| List, |
| NamedTuple, |
| NewType, |
| Optional, |
| Sized, |
| Tuple, |
| Type, |
| TypeVar, |
| Union, |
| ) |
| try: |
| from typing import Literal |
| except ImportError: |
| from typing_extensions import Literal |
|
|
| |
| |
| from jaxtyping import Bool, Complex, Float, Inexact, Int, Integer, Num, Shaped, UInt |
|
|
| |
| from omegaconf import DictConfig |
|
|
| |
| from torch import Tensor |
|
|
| |
| from typeguard import typechecked as typechecker |
|
|
|
|
def broadcast(tensor, src=0):
    """Broadcast ``tensor`` from rank ``src`` if a process group is active.

    When ``torch.distributed`` is unavailable or not initialized this is a
    no-op, so the same call is safe in single-process runs. The result of
    ``torch.distributed.broadcast`` is discarded; the code relies on it
    updating ``tensor`` in place.

    Args:
        tensor: The tensor to broadcast / receive into.
        src: Rank that owns the data to send.

    Returns:
        The same ``tensor`` object that was passed in.
    """
    if _distributed_available():
        torch.distributed.broadcast(tensor, src=src)
    return tensor
|
|
def _distributed_available():
    """Return whether torch.distributed collectives can be used right now.

    ``is_available()`` is checked first, so ``is_initialized()`` is only
    consulted when the distributed package is present in this torch build.
    """
    if not torch.distributed.is_available():
        return False
    return torch.distributed.is_initialized()
|
|
def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    """Instantiate ``fields`` from ``cfg`` and wrap it as an OmegaConf structured config.

    Args:
        fields: Callable invoked as ``fields(**cfg)``; its result is handed to
            ``OmegaConf.structured`` (presumably a dataclass type — confirm
            against callers).
        cfg: Keyword arguments for ``fields``. ``None`` (the default) is
            treated as an empty mapping. A ``'--local-rank'`` entry, if
            present, is dropped before instantiation. The caller's mapping
            is never modified.

    Returns:
        The config object produced by ``OmegaConf.structured``.
    """
    # The previous version evaluated `'--local-rank' in None` for the
    # declared default and raised TypeError; treat None as "no overrides".
    if cfg is None:
        cfg = {}

    # '--local-rank' is not a valid Python identifier, so it can never be a
    # keyword accepted by ``fields``; presumably it leaks in from a
    # distributed launcher's CLI arguments. Build a filtered copy rather than
    # `del`-ing from ``cfg`` so the caller's object is left untouched.
    kwargs = {key: value for key, value in cfg.items() if key != "--local-rank"}

    return OmegaConf.structured(fields(**kwargs))