vits-simple-api-gsv / logger.py
Artrajz's picture
init
960cd20
raw
history blame contribute delete
No virus
3.7 kB
import os
import re
import sys
import logging
import logzero
import warnings
from contants import config
from logging.handlers import TimedRotatingFileHandler
# Filter out log records that contain known warning messages
class SpecificWarningFilter(logging.Filter):
    """Logging filter that rejects records mentioning known warning texts.

    A record is dropped when its formatted message contains any of the
    configured substrings (plain, case-sensitive ``in`` test); every other
    record is let through.
    """

    def __init__(self, warning_messages):
        """Store the substrings that mark a record as unwanted.

        Args:
            warning_messages: Iterable of substrings to look for in each
                record's formatted message.
        """
        super().__init__()
        self.warning_messages = warning_messages

    def filter(self, record):
        """Decide whether *record* should be emitted.

        Args:
            record: The ``logging.LogRecord`` being considered.

        Returns:
            False if the formatted message contains any configured
            substring, True otherwise.
        """
        # Format the message once: getMessage() re-applies the %-args on
        # every call, so calling it per substring repeats identical work.
        text = record.getMessage()
        return not any(fragment in text for fragment in self.warning_messages)
# Warning texts to silence through the ``warnings`` module. Each entry is
# handed to ``warnings.filterwarnings`` as its ``message`` pattern, so it is
# treated as a regular expression rather than a literal string.
# NOTE(review): the leading "1" in "1Torch" looks odd — presumably it mirrors
# the upstream warning text verbatim; confirm before "fixing" it.
ignore_warning_messages = [
    "stft with return_complex=False is deprecated",
    "1Torch was not compiled with flash attention",
    "torch.nn.utils.weight_norm is deprecated",
    "Some weights of the model checkpoint.*were not used.*initializing.*from the checkpoint.*",
]

# Register one "ignore" filter per pattern.
for message in ignore_warning_messages:
    warnings.filterwarnings("ignore", message=message)
class WarningFilter(logging.Filter):
    """Logging filter that discards records at exactly WARNING level."""

    def filter(self, record):
        """Return False for WARNING records and True for every other level."""
        is_warning = record.levelno == logging.WARNING
        return not is_warning
# Keep logzero's default logger at WARNING and above.
logzero.loglevel(logging.WARNING)

# Named logger for the application. No handler is attached to it here; the
# handlers are installed on the root logger further down.
logger = logging.getLogger("vits-simple-api")

# Resolve the configured level name (case-insensitive) to a logging constant.
level = config.log_config.logging_level.upper()
level_dict = {
    'DEBUG': logging.DEBUG,
    'INFO': logging.INFO,
    'WARNING': logging.WARNING,
    'ERROR': logging.ERROR,
    'CRITICAL': logging.CRITICAL,
}

# An unrecognised level name used to abort the import with a bare KeyError.
# Fall back to INFO instead and report the problem once the handlers exist.
_rejected_level = None
if level not in level_dict:
    _rejected_level = level
    level = "INFO"
logging.getLogger().setLevel(level_dict[level])

# Choose the record layout by level: INFO gets the compact form, every other
# level (DEBUG included) additionally shows where the call was made.
_date_format = '%Y-%m-%d %H:%M:%S'
if level == "INFO":
    formatter = logging.Formatter('%(asctime)s [%(levelname)s] %(message)s',
                                  datefmt=_date_format)
else:
    formatter = logging.Formatter(
        '%(asctime)s [%(levelname)s] %(message)s [in %(module)s.%(funcName)s:%(lineno)d]',
        datefmt=_date_format)

# Log files live under <abs_path>/<logs_path>; create the directory if needed.
logs_path = os.path.join(config.abs_path, config.log_config.logs_path)
os.makedirs(logs_path, exist_ok=True)
log_file = os.path.join(logs_path, 'latest.log')

# NOTE(review): LOGS_BACKUPCOUNT is looked up on the top-level config object,
# while the other log settings come from config.log_config — confirm this is
# the intended attribute; when it is absent the default of 30 is used.
backup_count = getattr(config, "LOGS_BACKUPCOUNT", 30)

# Roll the file over at midnight, keeping up to backup_count old files.
handler = TimedRotatingFileHandler(log_file, when="midnight", interval=1, backupCount=backup_count, encoding='utf-8')
# Date-based suffix appended to the base file name on rollover.
handler.suffix = "%Y-%m-%d.log"
handler.setFormatter(formatter)

# Remove all handlers already on the root logger (e.g. a default
# StreamHandler) before installing the file and console handlers.
logging.getLogger().handlers = []
logging.getLogger().addHandler(handler)

console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logging.getLogger().addHandler(console_handler)

# Per-library overrides, independent of the global level chosen above.
logging.getLogger('numba').setLevel(logging.WARNING)
logging.getLogger("langid.langid").setLevel(logging.INFO)
logging.getLogger("apscheduler.scheduler").setLevel(logging.INFO)
# Drop WARNING-level records logged on transformers.modeling_utils.
logging.getLogger("transformers.modeling_utils").addFilter(WarningFilter())

# Now that handlers are in place, report a level name that had to be replaced.
if _rejected_level is not None:
    logger.warning("Unknown logging level %r in config; falling back to INFO", _rejected_level)
# Custom function to handle uncaught exceptions
def handle_exception(exc_type, exc_value, exc_traceback):
    """Log an uncaught exception; suitable for use as ``sys.excepthook``.

    Args:
        exc_type: Class of the exception that was raised.
        exc_value: The exception instance.
        exc_traceback: Traceback object associated with the exception.

    Returns:
        None. A KeyboardInterrupt is passed to the interpreter's original
        hook unchanged; anything else is logged at ERROR level together
        with its traceback.
    """
    details = (exc_type, exc_value, exc_traceback)
    if not issubclass(exc_type, KeyboardInterrupt):
        logger.error("Uncaught exception", exc_info=details)
        return
    # Ctrl-C is not an error worth logging: let the default hook handle it.
    sys.__excepthook__(*details)
# Install handle_exception as the interpreter-wide hook that is called for
# exceptions nothing else caught.
sys.excepthook = handle_exception