|
import os |
|
import re |
|
import numpy as np |
|
import logging |
|
|
|
# Every (name, level) pair that init_log() has already applied.
logs = set()

# The StreamHandler that init_log() attached to each logger name. Kept so a
# later call with a different level can retarget that handler instead of
# stacking a second one (which would print every record twice).
_stream_handlers = {}


def init_log(name: str, level: int = logging.INFO) -> logging.Logger:
    """Configure the logger called *name* and return it.

    The first call for a given name attaches one ``StreamHandler`` that
    formats records as ``[asctime][levelname] message``. If the
    ``SLURM_PROCID`` environment variable is set, a filter is also added to
    the logger that drops every record unless that value is ``0``.

    Calling again with a ``(name, level)`` pair that was already applied
    changes nothing and returns the same logger. Calling again with the same
    name but a new level updates the level of the logger and of the handler
    installed here; no additional handler or filter is added.

    Args:
        name: Name passed to ``logging.getLogger``.
        level: Threshold applied to both the logger and its stream handler.

    Returns:
        The configured ``logging.Logger`` (never ``None``).

    Raises:
        ValueError: If ``SLURM_PROCID`` is set but is not an integer. No
            state is modified in that case.
    """
    logger = logging.getLogger(name)
    if (name, level) in logs:
        # Already applied: hand back the logger so callers can always write
        # ``logger = init_log(...)``.
        return logger

    handler = _stream_handlers.get(name)
    if handler is None:
        handler = logging.StreamHandler()
        handler.setFormatter(
            logging.Formatter("[%(asctime)s][%(levelname)8s] %(message)s")
        )
        if "SLURM_PROCID" in os.environ:
            # Parse before touching any state so a malformed value leaves the
            # logger and the registries untouched.
            rank = int(os.environ["SLURM_PROCID"])
            logger.addFilter(lambda record: rank == 0)
        logger.addHandler(handler)
        _stream_handlers[name] = handler

    handler.setLevel(level)
    logger.setLevel(level)
    logs.add((name, level))
    return logger
|
|