# LexaLCM_Pre0 / lcm/utils/logging.py
# Lexa — Initial commit (3d79eb3)
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
#
import importlib
import os
import subprocess
from pathlib import Path
from typing import Dict

import torch.distributed as dist
from fairseq2.gang import get_rank
from fairseq2.logging import get_log_writer
from fairseq2.recipes.logging import _setup_aten_logging, _setup_nccl_logging
from fairseq2.recipes.utils.log import log_environment_info
from fairseq2.typing import Device
# Module-level log writer obtained from fairseq2's `get_log_writer`, named
# after this module.
logger = get_log_writer(__name__)

# Repo names accepted by `log_git_status` and iterated, in this order, by
# `log_lcm_environment`.
LCM_REPOS = "lcm fairseq2 sonar stopes".split()
def setup_additional_logging(log_folder: Path):
    """Point fairseq2's ATen and NCCL logging setup at a per-job, per-rank file.

    The file is ``log_folder / "<SLURM_JOB_ID>_<rank>.log"``. When
    ``SLURM_JOB_ID`` is not set, the job id part falls back to ``"local"``.
    Both ``_setup_aten_logging`` and ``_setup_nccl_logging`` receive that path
    with ``force=False``.
    """
    job_id: str = os.environ.get("SLURM_JOB_ID", "local")
    rank = get_rank()
    per_rank_log = log_folder / f"{job_id}_{rank}.log"
    for configure in (_setup_aten_logging, _setup_nccl_logging):
        configure(per_rank_log, force=False)
def log_git_status(
    repo: str = "lcm",
    tolerate_uncommitted: bool = False,
) -> str:
    """Log and return the commit hash of the git checkout of an LCM core repo.

    The checkout directory is the directory of the importable package named
    ``repo``. Its working tree is inspected with ``git status --porcelain``
    before the hash is read with ``git rev-parse HEAD``.

    Args:
        repo: One of ``LCM_REPOS``. It must also be importable as a top-level
            package.
        tolerate_uncommitted: If True, uncommitted changes to tracked files only
            produce a warning. Otherwise they raise.

    Returns:
        The stripped output of ``git rev-parse HEAD`` for the checkout.

    Raises:
        AssertionError: If ``repo`` is not in ``LCM_REPOS``, or if tracked files
            have uncommitted changes and ``tolerate_uncommitted`` is False.
        ValueError: If the package cannot be located, or if a git command
            cannot be run or exits with a non-zero status.
    """
    if repo not in LCM_REPOS:
        # Explicit raise instead of `assert` so the check survives `python -O`.
        raise AssertionError(
            f"Only the LCM core repos ({LCM_REPOS}) are supported in `log_git_status`"
        )

    repo_path = _locate_repo(repo)

    status = _run_git(repo, repo_path, "status", "--porcelain")
    # Porcelain v1 lines look like "XY <path>". Every tracked change counts
    # (e.g. "MM", " D", "AM" as well as " M"/"M "). Untracked ("??") and
    # ignored ("!!") entries are not counted.
    uncommitted = sum(
        1 for line in status.splitlines() if line and not line.startswith(("??", "!!"))
    )
    if uncommitted > 0:
        if tolerate_uncommitted:
            logger.warning(
                f"Changes to {repo} should be committed before running a job "
                f"- found {uncommitted} change(s)."
                " We will continue regardless, but the git commit hashes are unreliable!"
            )
        else:
            raise AssertionError(
                f"Changes to {repo} should be committed before running a job "
                f"- found {uncommitted} change(s). "
                "If running tests try adding `--debug-training`"
            )

    commit_hash = _run_git(repo, repo_path, "rev-parse", "HEAD").strip()
    logger.info(f"{repo} ({repo_path}) commit hash: {commit_hash}")
    return commit_hash


def _locate_repo(repo: str) -> str:
    """Return the directory containing the importable top-level package ``repo``."""
    try:
        module_file = getattr(importlib.import_module(repo), "__file__", None)
    except ImportError as err:
        raise ValueError(
            f"Could not import `{repo}` to locate its git checkout"
        ) from err
    if module_file is None:
        # e.g. namespace packages have no `__file__`.
        raise ValueError(f"Could not determine the source directory of `{repo}`")
    return os.path.dirname(module_file)


def _run_git(repo: str, repo_path: str, *args: str) -> str:
    """Run ``git <args>`` with ``repo_path`` as cwd and return stdout as text.

    Raises ValueError, chained to the underlying error, if git cannot be
    started or exits with a non-zero status.
    """
    try:
        result = subprocess.run(
            ["git", *args],
            cwd=repo_path,  # no shell: paths with spaces/metacharacters are safe
            capture_output=True,
            check=True,
            encoding="utf-8",
            errors="replace",  # porcelain paths may not be valid UTF-8
        )
    except (OSError, subprocess.CalledProcessError) as err:
        raise ValueError(
            f"Could not check the git revision hash, make sure you can run `git status` in {repo} ({repo_path})"
        ) from err
    return result.stdout
def log_lcm_environment(tolerate_uncommitted: bool = False) -> Dict:
    """Collect the current commit hash of every repo listed in ``LCM_REPOS``.

    Each repo is checked with ``log_git_status``, in ``LCM_REPOS`` order, for
    traceability and reproducibility of the run.

    Args:
        tolerate_uncommitted: Forwarded unchanged to ``log_git_status``.

    Returns:
        A dict mapping each repo name to the hash ``log_git_status`` returned.
    """
    hashes_by_repo = {}
    for repo_name in LCM_REPOS:
        hashes_by_repo[repo_name] = log_git_status(repo_name, tolerate_uncommitted)
    return hashes_by_repo
def log_env_variables(device: Device) -> None:
    """Log environment variables useful for debugging, then environment info.

    Every variable whose name starts with ``SLURM_``, ``SUBMITIT_``, ``NCCL_``,
    ``FI_``, ``CUDA_``, ``FAIRSEQ2_`` or ``TORCH_`` is logged, together with
    ``MASTER_ADDR``, ``MASTER_PORT``, ``RANK``, ``WORLD_SIZE``, ``LOCAL_RANK``
    and ``LOCAL_WORLD_SIZE``. Variables are logged in sorted name order, each
    prefixed with the process rank. Afterwards fairseq2's
    ``log_environment_info`` dumps fairseq2, torch, nccl and other relevant
    metadata.

    Args:
        device: Forwarded to ``log_environment_info``.
    """
    prefixes = ("SLURM_", "SUBMITIT_", "NCCL_", "FI_", "CUDA_", "FAIRSEQ2_", "TORCH_")
    exact_names = {
        "MASTER_ADDR",
        "MASTER_PORT",
        "RANK",
        "WORLD_SIZE",
        "LOCAL_RANK",
        "LOCAL_WORLD_SIZE",
    }

    # `dist.get_rank()` raises when no default process group is initialized.
    # In that case, fall back to fairseq2's `get_rank()` (the same helper
    # `setup_additional_logging` uses) so this works before/without dist init.
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
    else:
        rank = get_rank()

    for key, value in sorted(os.environ.items()):
        if key.startswith(prefixes) or key in exact_names:
            logger.info(f"R{rank} -- {key}={value}")

    # For Fairseq2, torch and devices
    log_environment_info(logger, device)