Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| from enum import Enum | |
| import os | |
| from pathlib import Path | |
| from typing import Any, Dict, Optional | |
class ClusterType(Enum):
    """Identifiers for the compute clusters this module can configure.

    Each member's value is the lowercase string name of the cluster.
    """

    AWS = "aws"
    FAIR = "fair"
    RSC = "rsc"
def _guess_cluster_type() -> ClusterType:
    """Infer the cluster type from the host's kernel release and hostname.

    Returns:
        ``ClusterType.AWS`` on a Linux host whose kernel release ends with
        ``-aws``; ``ClusterType.RSC`` on a Linux host whose node name starts
        with ``rsc``; ``ClusterType.FAIR`` in every other case.
    """
    # os.uname() is only available on Unix. Without this guard the function
    # raised AttributeError on other platforms instead of reaching the FAIR
    # fallback that every other unrecognised host already gets.
    if not hasattr(os, "uname"):
        return ClusterType.FAIR

    uname = os.uname()
    if uname.sysname != "Linux":
        return ClusterType.FAIR

    if uname.release.endswith("-aws"):
        # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
        return ClusterType.AWS
    if uname.nodename.startswith("rsc"):
        # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
        return ClusterType.RSC

    return ClusterType.FAIR
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
    """Return ``cluster_type`` when provided, otherwise guess it from the host.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to auto-detect.

    Returns:
        The given ``cluster_type`` unchanged, or the result of
        ``_guess_cluster_type()`` when ``cluster_type`` is ``None``.
    """
    if cluster_type is not None:
        return cluster_type
    return _guess_cluster_type()
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the root checkpoint directory for a cluster.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to auto-detect.

    Returns:
        An absolute path directly under ``/`` for the resolved cluster, or
        ``None`` if no cluster type could be resolved.
    """
    resolved_type = get_cluster_type(cluster_type)
    if resolved_type is None:
        return None

    # Directory name (relative to the filesystem root) for each cluster.
    dirname_by_cluster = {
        ClusterType.AWS: "checkpoints",
        ClusterType.FAIR: "checkpoint",
        ClusterType.RSC: "checkpoint/dino",
    }
    return Path("/").joinpath(dirname_by_cluster[resolved_type])
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
    """Return the current user's checkpoint directory for a cluster.

    The user name is read from the ``USER`` environment variable and appended
    to the path returned by ``get_checkpoint_path``.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to auto-detect.

    Returns:
        ``<checkpoint path>/<USER>``, or ``None`` if ``get_checkpoint_path``
        returns ``None``.

    Raises:
        AssertionError: If the ``USER`` environment variable is not set.
    """
    base_path = get_checkpoint_path(cluster_type)
    if base_path is None:
        return None

    username = os.environ.get("USER")
    if username is None:
        # Raised explicitly rather than through an `assert` statement so the
        # check is not stripped under `python -O`; there a missing USER would
        # otherwise surface as an opaque TypeError from `Path / None`.
        # The exception type is kept as AssertionError for backward compatibility.
        raise AssertionError(
            "the USER environment variable must be set to build the user checkpoint path"
        )
    return base_path / username
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
    """Return the SLURM partition name associated with a cluster.

    Args:
        cluster_type: Explicit cluster type, or ``None`` to auto-detect.

    Returns:
        The partition name for the resolved cluster, or ``None`` if no
        cluster type could be resolved.
    """
    resolved_type = get_cluster_type(cluster_type)
    if resolved_type is None:
        return None

    partition_by_cluster = {
        ClusterType.AWS: "learnlab",
        ClusterType.FAIR: "learnlab",
        ClusterType.RSC: "learn",
    }
    return partition_by_cluster[resolved_type]
def get_slurm_executor_parameters(
    nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
) -> Dict[str, Any]:
    """Build the dictionary of SLURM executor parameters for a job.

    Args:
        nodes: Number of nodes to request.
        num_gpus_per_node: GPUs per node; also used as the number of tasks
            per node (one task per GPU).
        cluster_type: Explicit cluster type, or ``None`` to auto-detect.
        **kwargs: Extra parameters; these are applied last and override any
            default with the same key.

    Returns:
        The default parameters, adjusted for the resolved cluster, with
        ``kwargs`` merged on top.
    """
    partition = get_slurm_partition(cluster_type)
    resolved_type = get_cluster_type(cluster_type)

    # Cluster-specific adjustments: AWS and RSC use 12 CPUs per task instead
    # of the default 10, and AWS additionally omits the "mem_gb" entry.
    on_aws = resolved_type == ClusterType.AWS
    cpus_per_task = 12 if on_aws or resolved_type == ClusterType.RSC else 10

    defaults: Dict[str, Any] = {}
    if not on_aws:
        # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
        defaults["mem_gb"] = 0
    defaults["gpus_per_node"] = num_gpus_per_node
    defaults["tasks_per_node"] = num_gpus_per_node  # one task per GPU
    defaults["cpus_per_task"] = cpus_per_task
    defaults["nodes"] = nodes
    defaults["slurm_partition"] = partition

    # Caller-supplied values take precedence over the defaults.
    return {**defaults, **kwargs}