Student0809's picture
Add files using upload-large-folder tool
7feac49 verified
# Copyright (c) Alibaba, Inc. and its affiliates.
import importlib.util
import logging
import os
from contextlib import contextmanager
from types import MethodType
from typing import Optional
from modelscope.utils.logger import get_logger as get_ms_logger
# Defined here to avoid a circular import.
def _is_local_master():
local_rank = int(os.getenv('LOCAL_RANK', -1))
return local_rank in {-1, 0}
# Names of loggers that ``get_logger`` has already configured.
init_loggers = {}
# Previous, more verbose format kept for reference:
# formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger_format = logging.Formatter('[%(levelname)s:%(name)s] %(message)s')
# Dedup keys already emitted by ``info_once`` / ``warning_once``. These sets are
# module-wide, so they are shared by every logger the helpers are bound to.
info_set = set()
warning_set = set()


def _log_once(seen, log_method, msg, args, kwargs):
    """Emit ``msg`` through ``log_method`` unless its dedup key is already in ``seen``.

    The dedup key is the ``hash_id`` keyword when given and truthy, otherwise
    ``msg`` itself (the unformatted template, so calls that differ only in
    ``args`` are deduplicated together). ``hash_id`` is removed before the
    remaining ``args``/``kwargs`` are forwarded, because ``Logger`` methods do
    not accept it.
    """
    hash_id = kwargs.pop('hash_id', None) or msg
    if hash_id in seen:
        return
    seen.add(hash_id)
    log_method(msg, *args, **kwargs)


def info_once(self, msg, *args, **kwargs):
    """Log ``msg`` at INFO level only the first time its dedup key is seen.

    Args:
        self: The logger to log through.
        msg: Message (``%``-style template, as for ``Logger.info``).
        *args: Formatting arguments for ``msg``.
        **kwargs: Keyword arguments forwarded to ``Logger.info``, plus an
            optional ``hash_id`` used as the dedup key instead of ``msg``.
    """
    _log_once(info_set, self.info, msg, args, kwargs)


def warning_once(self, msg, *args, **kwargs):
    """Log ``msg`` at WARNING level only the first time its dedup key is seen.

    Args:
        self: The logger to log through.
        msg: Message (``%``-style template, as for ``Logger.warning``).
        *args: Formatting arguments for ``msg``.
        **kwargs: Keyword arguments forwarded to ``Logger.warning``, plus an
            optional ``hash_id`` used as the dedup key instead of ``msg``.
    """
    _log_once(warning_set, self.warning, msg, args, kwargs)
def get_logger(log_file: Optional[str] = None, log_level: Optional[int] = None, file_mode: str = 'w'):
    """Get the package-level logger, configuring it on the first call.

    The logger is named after the top-level package of this module and does
    not propagate to the root logger. On the first call:
    - a console handler is attached;
    - on the local master process, a file handler for ``log_file`` is attached;
    - the logger level is set;
    - ``info_once`` / ``warning_once`` are bound onto the logger.
    Later calls only try to add a file handler for ``log_file`` if the logger
    has none yet.

    Args:
        log_file: Log filename. If specified, a UTF-8 file handler is added to
            the logger.
        log_level: Logging level. When None, the ``LOG_LEVEL`` environment
            variable is used (default ``'INFO'``); names that do not resolve to
            a ``logging`` level fall back to ``logging.INFO``.
        file_mode: Mode used to open ``log_file`` (defaults to ``'w'``).

    Returns:
        The configured ``logging.Logger``.
    """
    if log_level is None:
        level_name = os.getenv('LOG_LEVEL', 'INFO').upper()
        log_level = getattr(logging, level_name, logging.INFO)
        # getattr can also hit non-level module attributes (e.g.
        # LOG_LEVEL=basic_format yields the string logging.BASIC_FORMAT), which
        # setLevel would reject.
        if not isinstance(log_level, int):
            log_level = logging.INFO
    logger_name = __name__.split('.')[0]
    logger = logging.getLogger(logger_name)
    logger.propagate = False
    if logger_name in init_loggers:
        # Already configured: only attach a file handler if one is missing.
        add_file_handler_if_needed(logger, log_file, file_mode, log_level)
        return logger

    # handle duplicate logs to the console
    # Starting in 1.8.0, PyTorch DDP attaches a StreamHandler <stderr> (NOTSET)
    # to the root logger. As logger.propagate is True by default, this root
    # level handler causes logging messages from rank>0 processes to
    # unexpectedly show up on the console, creating much unwanted clutter.
    # To fix this issue, we set the root logger's StreamHandler, if any, to log
    # at the ERROR level.
    for handler in logger.root.handlers:
        # Exact type check so FileHandlers (a StreamHandler subclass) attached
        # to the root logger are left untouched.
        if type(handler) is logging.StreamHandler:
            handler.setLevel(logging.ERROR)

    stream_handler = logging.StreamHandler()
    handlers = [stream_handler]
    is_worker0 = _is_local_master()
    # Only the local master opens the log file.
    if is_worker0 and log_file is not None:
        # Explicit UTF-8 so non-ASCII messages do not depend on the platform locale.
        file_handler = logging.FileHandler(log_file, file_mode, encoding='utf-8')
        handlers.append(file_handler)
    for handler in handlers:
        handler.setFormatter(logger_format)
        handler.setLevel(log_level)
        logger.addHandler(handler)
    # Non-master processes only emit ERROR and above.
    logger.setLevel(log_level if is_worker0 else logging.ERROR)
    init_loggers[logger_name] = True
    logger.info_once = MethodType(info_once, logger)
    logger.warning_once = MethodType(warning_once, logger)
    return logger
logger = get_logger()
ms_logger = get_ms_logger()
# Give the first handler of both loggers the shared compact format.
logger.handlers[0].setFormatter(logger_format)
ms_logger.handlers[0].setFormatter(logger_format)
# Level name from the environment; passed to setLevel as a string.
log_level = os.environ.get('LOG_LEVEL', 'INFO').upper()
# Only the local master uses the configured level; other processes log errors only.
ms_logger.setLevel(log_level if _is_local_master() else logging.ERROR)
@contextmanager
def ms_logger_ignore_error():
    """Temporarily suppress ModelScope log messages below CRITICAL.

    The ModelScope logger's previous level is restored on exit, including
    when the ``with`` body raises.
    """
    target = get_ms_logger()
    saved_level = target.level
    target.setLevel(logging.CRITICAL)
    try:
        yield
    finally:
        # Always undo the temporary level change.
        target.setLevel(saved_level)
def add_file_handler_if_needed(logger, log_file, file_mode, log_level):
    """Attach a file handler for ``log_file`` to ``logger`` unless it already has one.

    Args:
        logger: Logger to attach the handler to.
        log_file: Path of the log file; nothing is done when None.
        file_mode: Mode used to open ``log_file`` (e.g. ``'w'`` or ``'a'``).
        log_level: Level set on the new handler.
    """
    if log_file is None:
        return
    # At most one file handler per logger: a later call with a different
    # ``log_file`` does not redirect output.
    if any(isinstance(handler, logging.FileHandler) for handler in logger.handlers):
        return
    # Same rank check as ``get_logger``, so non-master processes never open
    # (and, with mode 'w', truncate) the shared log file.
    if not _is_local_master():
        return
    # Explicit UTF-8 so non-ASCII messages do not depend on the platform locale.
    file_handler = logging.FileHandler(log_file, file_mode, encoding='utf-8')
    file_handler.setFormatter(logger_format)
    file_handler.setLevel(log_level)
    logger.addHandler(file_handler)