import os
import re
import logging

import numpy as np

# Registry of (logger name, level) pairs already configured by init_log, so
# repeated calls with the same pair do not attach duplicate handlers.
logs = set()


def init_log(name: str, level: int = logging.INFO) -> logging.Logger:
    """Configure and return the logger called *name*.

    On the first call for a given ``(name, level)`` pair this sets the
    logger's level and attaches a ``StreamHandler`` formatted as
    ``[time][LEVEL] message``. When the ``SLURM_PROCID`` environment variable
    is set, it also adds a filter so only the process with rank 0 emits
    records. Later calls with the same pair return the already-configured
    logger without adding another handler.

    Args:
        name: Logger name passed to ``logging.getLogger``.
        level: Level applied to both the logger and its stream handler.

    Returns:
        The ``logging.Logger`` instance for *name*.

    Raises:
        ValueError: If ``SLURM_PROCID`` is set but is not an integer.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        # Already configured: hand back the same logger instead of None so
        # that every call site can use the return value.
        return logger
    # NOTE(review): the registry is keyed on (name, level), so calling again
    # with the same name but a different level attaches a second handler and
    # duplicates output. Confirm whether that is intended.
    logs.add((name, level))

    logger.setLevel(level)
    ch = logging.StreamHandler()
    ch.setLevel(level)

    if "SLURM_PROCID" in os.environ:
        rank = int(os.environ["SLURM_PROCID"])
        # In multi-process SLURM jobs, suppress output from every rank but 0
        # so messages are not repeated once per process.
        logger.addFilter(lambda record: rank == 0)

    format_str = "[%(asctime)s][%(levelname)8s] %(message)s"
    formatter = logging.Formatter(format_str)
    ch.setFormatter(formatter)
    logger.addHandler(ch)
    return logger