|
|
|
|
|
import importlib.util |
|
|
import logging |
|
|
import os |
|
|
from contextlib import contextmanager |
|
|
from types import MethodType |
|
|
from typing import Optional |
|
|
|
|
|
from modelscope.utils.logger import get_logger as get_ms_logger |
|
|
|
|
|
|
|
|
|
|
|
def _is_local_master():
    """Report whether this process should act as the node-local master.

    The ``LOCAL_RANK`` environment variable is read as an int; when it is unset
    it is treated as -1. Ranks -1 and 0 count as master, any other rank does not.
    A non-integer value raises ``ValueError`` from ``int()``.
    """
    rank = int(os.environ.get('LOCAL_RANK', '-1'))
    return rank == -1 or rank == 0
|
|
|
|
|
|
|
|
# Names of loggers that get_logger has already configured; used like a set
# (get_logger only ever stores True as the value).
init_loggers = dict()

# Shared formatter for every handler: "[LEVEL:logger_name] message".
logger_format = logging.Formatter(fmt='[%(levelname)s:%(name)s] %(message)s')

# De-duplication keys already emitted by info_once and warning_once, respectively.
info_set, warning_set = set(), set()
|
|
|
|
|
|
|
|
def info_once(self, msg, *args, **kwargs):
    """Log ``msg`` at INFO level only the first time its de-duplication key is seen.

    ``get_logger`` binds this onto the package logger via ``MethodType``.

    Args:
        msg: Message, optionally with %-style placeholders filled from ``args``.
        *args: Formatting arguments forwarded to ``self.info``.
        hash_id: Optional keyword giving the de-duplication key. When it is
            missing or falsy, ``msg`` itself (the unformatted template) is the
            key, so the same template with different ``args`` is logged once.
        **kwargs: Remaining keyword arguments (e.g. ``exc_info``, ``extra``),
            forwarded to ``self.info``.
    """
    # Pop rather than get: ``hash_id`` is not a Logger.info keyword and would
    # raise TypeError if forwarded.
    hash_id = kwargs.pop('hash_id', None) or msg
    if hash_id in info_set:
        return
    info_set.add(hash_id)
    self.info(msg, *args, **kwargs)
|
|
|
|
|
|
|
|
def warning_once(self, msg, *args, **kwargs):
    """Log ``msg`` at WARNING level only the first time its de-duplication key is seen.

    ``get_logger`` binds this onto the package logger via ``MethodType``.

    Args:
        msg: Message, optionally with %-style placeholders filled from ``args``.
        *args: Formatting arguments forwarded to ``self.warning``.
        hash_id: Optional keyword giving the de-duplication key. When it is
            missing or falsy, ``msg`` itself (the unformatted template) is the
            key, so the same template with different ``args`` is logged once.
        **kwargs: Remaining keyword arguments (e.g. ``exc_info``, ``extra``),
            forwarded to ``self.warning``.
    """
    # Pop rather than get: ``hash_id`` is not a Logger.warning keyword and
    # would raise TypeError if forwarded.
    hash_id = kwargs.pop('hash_id', None) or msg
    if hash_id in warning_set:
        return
    warning_set.add(hash_id)
    self.warning(msg, *args, **kwargs)
|
|
|
|
|
|
|
|
def _resolve_log_level(level, default=logging.INFO):
    """Convert a level name (any case) or a digit string into a logging level int.

    Args:
        level: A level name such as ``'debug'`` / ``'WARN'``, or a string of digits.
        default: Value returned when ``level`` is not a registered level name.

    Returns:
        The numeric logging level.
    """
    name = str(level).strip().upper()
    if name.isdigit():
        return int(name)
    resolved = logging.getLevelName(name)
    # For unregistered names getLevelName returns the string 'Level <name>',
    # so only an int result is a real level.
    return resolved if isinstance(resolved, int) else default


def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'):
    """Get (and on the first call, configure) the package-level logger.

    The logger is named after the top-level package of this module and does not
    propagate to the root logger. The first call installs a console handler,
    plus a file handler on the local master when ``log_file`` is given. Later
    calls only add a file handler if the logger has none yet.

    Args:
        log_file: Log filename. If specified, a file handler is added to the
            logger (local master process only).
        log_level: Logging level. ``None`` reads the ``LOG_LEVEL`` environment
            variable (default ``'INFO'``). Level names such as ``'debug'`` are
            also accepted. Unknown names fall back to INFO.
        file_mode: Mode used to open ``log_file``; defaults to ``'w'``.

    Returns:
        The package ``logging.Logger``.
    """
    if log_level is None:
        log_level = _resolve_log_level(os.getenv('LOG_LEVEL', 'INFO'))
    elif isinstance(log_level, str):
        log_level = _resolve_log_level(log_level)

    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    if logger_name in init_loggers:
        # Already configured: at most attach a file handler. The logger's own
        # level and its existing handlers are left unchanged.
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # Raise plain console handlers on the root logger to ERROR. The exact-type
    # check is deliberate: FileHandler subclasses StreamHandler and is left alone.
    # NOTE(review): presumably this suppresses console clutter from handlers
    # that other libraries attach to root — confirm.
    for handler in logger.root.handlers:
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    is_worker0 = _is_local_master()
    handlers = [logging.StreamHandler()]
    if is_worker0 and log_file is not None:
        handlers.append(logging.FileHandler(log_file, file_mode))

    for handler in handlers:
        handler.setFormatter(logger_format)
        handler.setLevel(log_level)
        logger.addHandler(handler)

    # Non-master ranks only emit ERROR and above.
    logger.setLevel(log_level if is_worker0 else logging.ERROR)

    init_loggers[logger_name] = True

    logger.info_once = MethodType(info_once, logger)
    logger.warning_once = MethodType(warning_once, logger)
    return logger
|
|
|
|
|
|
|
|
# Module-wide loggers: this package's logger and ModelScope's logger, both
# given the shared formatter on their first handler.
logger = get_logger()
ms_logger = get_ms_logger()

logger.handlers[0].setFormatter(logger_format)
if ms_logger.handlers:
    ms_logger.handlers[0].setFormatter(logger_format)
log_level = os.getenv('LOG_LEVEL', 'INFO').upper()
if _is_local_master():
    # Resolve the name tolerantly, falling back to INFO like get_logger does,
    # instead of letting ms_logger.setLevel raise ValueError at import time
    # for an unknown LOG_LEVEL.
    _ms_level_name = log_level.strip()
    if _ms_level_name.isdigit():
        _ms_level = int(_ms_level_name)
    else:
        _ms_level = logging.getLevelName(_ms_level_name)
    ms_logger.setLevel(_ms_level if isinstance(_ms_level, int) else logging.INFO)
else:
    ms_logger.setLevel(logging.ERROR)
|
|
|
|
|
|
|
|
@contextmanager
def ms_logger_ignore_error():
    """Temporarily raise the ModelScope logger's threshold to CRITICAL.

    The logger's previous level is restored when the ``with`` block exits,
    whether it finishes normally or raises.
    """
    target = get_ms_logger()
    saved_level = target.level
    target.setLevel(logging.CRITICAL)
    try:
        yield
    finally:
        target.setLevel(saved_level)
|
|
|
|
|
|
|
|
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a file handler to ``logger`` unless it already has one.

    ``get_logger`` calls this for a logger that was configured on an earlier
    call. Nothing is added in any of these cases:

    - ``log_file`` is None;
    - any ``logging.FileHandler`` is already attached;
    - the process is not the local master.

    Args:
        logger: Logger to extend.
        log_file: Path of the log file, or None.
        file_mode: Mode used to open ``log_file`` (e.g. ``'w'`` or ``'a'``).
        log_level: Level applied to the new file handler only.
    """
    if log_file is None:
        return
    if any(isinstance(handler, logging.FileHandler) for handler in logger.handlers):
        return
    # Use the same rank rule as get_logger's first-time setup, so only the
    # local master opens (and, with mode 'w', truncates) the file.
    if not _is_local_master():
        return

    file_handler = logging.FileHandler(log_file, file_mode)
    file_handler.setFormatter(logger_format)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)
|
|
|