|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import ctypes |
|
|
from abc import abstractmethod |
|
|
from pathlib import Path |
|
|
from typing import ( |
|
|
Any, |
|
|
Dict, |
|
|
Iterable, |
|
|
Optional, |
|
|
Protocol, |
|
|
Sized, |
|
|
Type, |
|
|
TypeVar, |
|
|
Union, |
|
|
runtime_checkable, |
|
|
) |
|
|
|
|
|
import torch |
|
|
from omegaconf import DictConfig, OmegaConf |
|
|
|
|
|
# Directory three levels above this file; relative config paths are resolved
# against it by `working_dir_resolver`.
root_working_dir = Path(__file__).parent.parent.parent
|
|
|
|
|
|
|
|
def set_mkl_num_threads():
    """Restrict MKL to a single thread to avoid a thread explosion.

    Loads ``libmkl_rt.so`` through ``ctypes`` and calls its
    ``mkl_set_num_threads`` entry point with a reference to a C ``int``
    holding ``1``.
    """
    mkl_library = ctypes.CDLL("libmkl_rt.so")
    # This entry point is handed a pointer to the value, hence `byref`.
    single_thread = ctypes.c_int(1)
    mkl_library.mkl_set_num_threads(ctypes.byref(single_thread))
|
|
|
|
|
|
|
|
def working_dir_resolver(p: str):
    """OmegaConf resolver turning a relative path into an absolute ``file://`` URI.

    Args:
        p: Path to resolve, joined onto ``root_working_dir``.

    Returns:
        The string ``"file://"`` followed by the resolved absolute path.
    """
    absolute_path = (root_working_dir / p).resolve()
    return f"file://{absolute_path}"
|
|
|
|
|
|
|
|
def setup_conf():
    """Register the common Hydra config groups used in LCM (for now only the launcher).

    Adds two entries to the stopes ``config_registry``, each a ``file://`` URI
    built from the ``recipes`` directory located three levels above this file:

    * ``"lcm-common"`` -> ``recipes/common``
    * ``"lcm-root"``   -> ``recipes``
    """
    # Imported lazily so that stopes is only needed when this function is called.
    from stopes.pipelines import config_registry

    recipes_dir = Path(__file__).parent.parent.parent / "recipes"
    group_locations = {
        "lcm-common": recipes_dir / "common",
        "lcm-root": recipes_dir,
    }
    for group_name, location in group_locations.items():
        config_registry[group_name] = "file://" + str(location.resolve())
|
|
|
|
|
|
|
|
# Expose `working_dir_resolver` to OmegaConf under the name "realpath";
# `replace=True` overwrites any resolver already registered under that name.
OmegaConf.register_new_resolver(
    "realpath",
    working_dir_resolver,
    replace=True,
)
|
|
|
|
|
|
|
|
def torch_type(
    dtype: Optional[Union[str, torch.dtype]] = None,
) -> Optional[torch.dtype]:
    """Convert a dtype given as a string into the matching ``torch.dtype``.

    Args:
        dtype: ``None``, an actual ``torch.dtype``, or the name of one, written
            either fully qualified (``"torch.float16"``) or bare (``"float16"``).

    Returns:
        ``None`` when ``dtype`` is ``None``; a ``torch.dtype`` passed in is
        returned unchanged; a string is mapped to the ``torch.dtype`` it names.

    Raises:
        AssertionError: If ``dtype`` does not name a ``torch.dtype``.
    """
    if dtype is None:
        return None
    if isinstance(dtype, torch.dtype):
        return dtype

    # Look the name up on the torch module instead of eval()-ing the string:
    # eval would execute any expression contained in it.
    resolved = None
    if isinstance(dtype, str):
        name = dtype.strip()
        if name.startswith("torch."):
            name = name[len("torch."):]
        # Only plain attribute names are accepted (no dots, calls or operators).
        if name.isidentifier():
            resolved = getattr(torch, name, None)

    # Raised explicitly rather than through `assert` so the check survives
    # `python -O`, while keeping the exception type of the previous check.
    if not isinstance(resolved, torch.dtype):
        raise AssertionError(f"Invalid dtype value: {dtype}")
    return resolved
|
|
|
|
|
|
|
|
@runtime_checkable
class Batched(Sized, Protocol):
    """Structural type for batched data.

    An object satisfies this protocol when it has a length (``__len__``,
    inherited from ``Sized``) and supports integer indexing (``__getitem__``).
    Because the protocol is runtime-checkable, ``isinstance(obj, Batched)``
    can be used to test for those two methods.
    """

    @abstractmethod
    def __getitem__(self, i: int) -> Any:
        """Return the element stored at position ``i``."""
|
|
|
|
|
|
|
|
# Generic placeholder for the config class handled by `promote_config`.
T = TypeVar("T")
|
|
|
|
|
|
|
|
def promote_config(config: Union[T, DictConfig, Dict], config_cls: Type[T]) -> T:
    """Return ``config`` as an instance of ``config_cls``.

    Args:
        config: Either an instance of ``config_cls``, a plain ``dict``, or an
            OmegaConf ``DictConfig``.
        config_cls: The class the configuration should be promoted to.

    Returns:
        ``config`` itself when it already is a ``config_cls`` instance;
        otherwise a new ``config_cls`` built by ``dacite.from_dict``.

    Raises:
        AssertionError: If ``config`` is neither a dict / ``DictConfig`` nor an
            instance of ``config_cls``.
    """
    # Already-structured configs are only type-checked and passed through.
    if not isinstance(config, (Dict, DictConfig)):
        assert isinstance(config, config_cls), f"Unknown config type: {type(config)}"
        return config

    import dacite

    # dacite works on plain containers, so unwrap OmegaConf nodes first.
    raw_values = OmegaConf.to_container(config) if isinstance(config, DictConfig) else config

    # `cast=[Path]` lets dacite cast values into `Path`-typed fields.
    conversion_rules = dacite.Config(cast=[Path])
    return dacite.from_dict(data_class=config_cls, data=raw_values, config=conversion_rules)
|
|
|
|
|
|
|
|
def batched(inputs: Iterable, batch_size=10000) -> Iterable:
    """Lazily split ``inputs`` into consecutive lists of ``batch_size`` items.

    Args:
        inputs: Any iterable; it is consumed one item at a time.
        batch_size: Number of items per yielded list.

    Yields:
        Lists of exactly ``batch_size`` items, followed by one shorter list
        holding the leftover items if any remain.
    """
    pending = []
    for item in inputs:
        pending.append(item)
        if len(pending) != batch_size:
            continue
        yield pending
        # Start a fresh list so the batch already handed out is never mutated.
        pending = []
    if pending:
        yield pending
|
|
|