|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Provides cluster and tools configuration across clusters (slurm, dora, utilities). |
|
""" |
|
|
|
import logging |
|
import os |
|
from pathlib import Path |
|
import re |
|
import typing as tp |
|
|
|
import omegaconf |
|
|
|
from .utils.cluster import _guess_cluster_type |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class AudioCraftEnvironment: |
|
"""Environment configuration for teams and clusters. |
|
|
|
AudioCraftEnvironment picks compute cluster settings (slurm, dora) from the current running environment |
|
or declared variable and the loaded team configuration. Additionally, the AudioCraftEnvironment |
|
provides pointers to a reference folder resolved automatically across clusters that is shared across team members, |
|
allowing to share sigs or other files to run jobs. Finally, it provides dataset mappers to automatically |
|
map dataset file paths to new locations across clusters, allowing to use the same manifest of files across cluters. |
|
|
|
The cluster type is identified automatically and base configuration file is read from config/teams.yaml. |
|
Use the following environment variables to specify the cluster, team or configuration: |
|
|
|
AUDIOCRAFT_CLUSTER (optional): Cluster type to enforce. Useful if the cluster type |
|
cannot be inferred automatically. |
|
AUDIOCRAFT_CONFIG (optional): Path to yaml config holding the teams configuration. |
|
If not set, configuration is read from config/teams.yaml. |
|
AUDIOCRAFT_TEAM (optional): Name of the team. Recommended to set to your own team. |
|
Cluster configuration are shared across teams to match compute allocation, |
|
specify your cluster configuration in the configuration file under a key mapping |
|
your team name. |
|
""" |
|
_instance = None |
|
DEFAULT_TEAM = "default" |
|
|
|
def __init__(self) -> None: |
|
"""Loads configuration.""" |
|
self.team: str = os.getenv("AUDIOCRAFT_TEAM", self.DEFAULT_TEAM) |
|
cluster_type = _guess_cluster_type() |
|
cluster = os.getenv( |
|
"AUDIOCRAFT_CLUSTER", cluster_type.value |
|
) |
|
logger.info("Detecting cluster type %s", cluster_type) |
|
|
|
self.cluster: str = cluster |
|
|
|
config_path = os.getenv( |
|
"AUDIOCRAFT_CONFIG", |
|
Path(__file__) |
|
.parent.parent.joinpath("config/teams", self.team) |
|
.with_suffix(".yaml"), |
|
) |
|
self.config = omegaconf.OmegaConf.load(config_path) |
|
self._dataset_mappers = [] |
|
cluster_config = self._get_cluster_config() |
|
if "dataset_mappers" in cluster_config: |
|
for pattern, repl in cluster_config["dataset_mappers"].items(): |
|
regex = re.compile(pattern) |
|
self._dataset_mappers.append((regex, repl)) |
|
|
|
def _get_cluster_config(self) -> omegaconf.DictConfig: |
|
assert isinstance(self.config, omegaconf.DictConfig) |
|
return self.config[self.cluster] |
|
|
|
@classmethod |
|
def instance(cls): |
|
if cls._instance is None: |
|
cls._instance = cls() |
|
return cls._instance |
|
|
|
@classmethod |
|
def reset(cls): |
|
"""Clears the environment and forces a reload on next invocation.""" |
|
cls._instance = None |
|
|
|
@classmethod |
|
def get_team(cls) -> str: |
|
"""Gets the selected team as dictated by the AUDIOCRAFT_TEAM env var. |
|
If not defined, defaults to "labs". |
|
""" |
|
return cls.instance().team |
|
|
|
@classmethod |
|
def get_cluster(cls) -> str: |
|
"""Gets the detected cluster. |
|
This value can be overridden by the AUDIOCRAFT_CLUSTER env var. |
|
""" |
|
return cls.instance().cluster |
|
|
|
@classmethod |
|
def get_dora_dir(cls) -> Path: |
|
"""Gets the path to the dora directory for the current team and cluster. |
|
Value is overridden by the AUDIOCRAFT_DORA_DIR env var. |
|
""" |
|
cluster_config = cls.instance()._get_cluster_config() |
|
dora_dir = os.getenv("AUDIOCRAFT_DORA_DIR", cluster_config["dora_dir"]) |
|
logger.warning(f"Dora directory: {dora_dir}") |
|
return Path(dora_dir) |
|
|
|
@classmethod |
|
def get_reference_dir(cls) -> Path: |
|
"""Gets the path to the reference directory for the current team and cluster. |
|
Value is overridden by the AUDIOCRAFT_REFERENCE_DIR env var. |
|
""" |
|
cluster_config = cls.instance()._get_cluster_config() |
|
return Path(os.getenv("AUDIOCRAFT_REFERENCE_DIR", cluster_config["reference_dir"])) |
|
|
|
@classmethod |
|
def get_slurm_exclude(cls) -> tp.Optional[str]: |
|
"""Get the list of nodes to exclude for that cluster.""" |
|
cluster_config = cls.instance()._get_cluster_config() |
|
return cluster_config.get("slurm_exclude") |
|
|
|
@classmethod |
|
def get_slurm_partitions(cls, partition_types: tp.Optional[tp.List[str]] = None) -> str: |
|
"""Gets the requested partitions for the current team and cluster as a comma-separated string. |
|
|
|
Args: |
|
partition_types (list[str], optional): partition types to retrieve. Values must be |
|
from ['global', 'team']. If not provided, the global partition is returned. |
|
""" |
|
if not partition_types: |
|
partition_types = ["global"] |
|
|
|
cluster_config = cls.instance()._get_cluster_config() |
|
partitions = [ |
|
cluster_config["partitions"][partition_type] |
|
for partition_type in partition_types |
|
] |
|
return ",".join(partitions) |
|
|
|
@classmethod |
|
def resolve_reference_path(cls, path: tp.Union[str, Path]) -> Path: |
|
"""Converts reference placeholder in path with configured reference dir to resolve paths. |
|
|
|
Args: |
|
path (str or Path): Path to resolve. |
|
Returns: |
|
Path: Resolved path. |
|
""" |
|
path = str(path) |
|
|
|
if path.startswith("//reference"): |
|
reference_dir = cls.get_reference_dir() |
|
logger.warn(f"Reference directory: {reference_dir}") |
|
assert ( |
|
reference_dir.exists() and reference_dir.is_dir() |
|
), f"Reference directory does not exist: {reference_dir}." |
|
path = re.sub("^//reference", str(reference_dir), path) |
|
|
|
return Path(path) |
|
|
|
@classmethod |
|
def apply_dataset_mappers(cls, path: str) -> str: |
|
"""Applies dataset mapping regex rules as defined in the configuration. |
|
If no rules are defined, the path is returned as-is. |
|
""" |
|
instance = cls.instance() |
|
|
|
for pattern, repl in instance._dataset_mappers: |
|
path = pattern.sub(repl, path) |
|
|
|
return path |
|
|