Spaces:
Sleeping
Sleeping
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Utility functions for SLURM configuration and cluster settings. | |
| """ | |
| from enum import Enum | |
| import os | |
| import socket | |
| import typing as tp | |
| import omegaconf | |
class ClusterType(Enum):
    """Identifier for the compute environment the code is running on.

    Each member maps to a short lowercase string value, so a member can be
    looked up from its string form, e.g. ``ClusterType("aws")``.
    """

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
    LOCAL_DARWIN = "darwin"  # local macOS machine
    DEFAULT = "default"  # used for any other cluster.
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from the host OS and its fully-qualified domain name.

    Rules are checked in order and the first match wins:

    1. Linux whose kernel release ends in ``-aws``, or whose FQDN contains
       ``.ec2`` -> ``ClusterType.AWS``.
    2. FQDN ending in ``.fair`` -> ``ClusterType.FAIR``.
    3. FQDN ending in ``.facebook.com`` -> ``ClusterType.RSC``.
    4. macOS (``Darwin``) -> ``ClusterType.LOCAL_DARWIN``.
    5. Anything else -> ``ClusterType.DEFAULT``.
    """
    # `os.uname()` only exists on Unix and raises AttributeError elsewhere
    # (e.g. Windows). `platform.system()` / `platform.release()` work on every
    # OS and take their values from `os.uname()` when it is available, so the
    # results on Linux/macOS are unchanged.
    import platform

    system = platform.system()
    release = platform.release()
    fqdn = socket.getfqdn()

    if system == "Linux" and (release.endswith("-aws") or ".ec2" in fqdn):
        return ClusterType.AWS
    if fqdn.endswith(".fair"):
        return ClusterType.FAIR
    if fqdn.endswith(".facebook.com"):
        return ClusterType.RSC
    if system == "Darwin":
        return ClusterType.LOCAL_DARWIN
    return ClusterType.DEFAULT
def get_cluster_type(
    cluster_type: tp.Optional[ClusterType] = None,
) -> ClusterType:
    """Resolve the cluster type, guessing it from the current host if not given.

    Args:
        cluster_type: Explicit cluster type to use. When ``None``, the type is
            inferred with ``_guess_cluster_type()``.

    Returns:
        ``cluster_type`` itself when provided, otherwise the guessed type.
        Never ``None``, so the return annotation is the non-optional
        ``ClusterType``.
    """
    if cluster_type is None:
        return _guess_cluster_type()
    return cluster_type
def get_slurm_parameters(
    cfg: omegaconf.DictConfig, cluster_type: tp.Optional[ClusterType] = None
) -> omegaconf.DictConfig:
    """Update SLURM parameters in configuration based on cluster type.

    If the cluster type is not specified, it is inferred automatically from the
    current host. ``cfg`` is modified in place and the same object is returned.

    Args:
        cfg: SLURM configuration to adjust.
        cluster_type: Cluster to target; ``None`` triggers auto-detection.

    Returns:
        ``cfg``, after the cluster-specific adjustments have been applied.
    """
    # Local import; presumably kept out of module scope to avoid an import
    # cycle with the environment module (TODO confirm).
    from ..environment import AudioCraftEnvironment

    cluster_type = get_cluster_type(cluster_type)

    # AWS and RSC share the same overrides: `mem_per_gpu` and `constraint` are
    # cleared and `setup` is emptied. RSC additionally pins the partition.
    if cluster_type in (ClusterType.AWS, ClusterType.RSC):
        cfg["mem_per_gpu"] = None
        cfg["constraint"] = None
        cfg["setup"] = []
    if cluster_type == ClusterType.RSC:
        cfg["partition"] = "learn"

    # Only override `exclude` when the environment actually provides a value.
    slurm_exclude = AudioCraftEnvironment.get_slurm_exclude()
    if slurm_exclude is not None:
        cfg["exclude"] = slurm_exclude
    return cfg