Spaces:
Runtime error
Runtime error
#!/usr/bin/env python3 | |
# -*- coding:utf-8 -*- | |
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved. | |
from loguru import logger | |
import inspect | |
import os | |
import sys | |
def get_caller_name(depth=0):
    """
    Return the ``__name__`` of the module that called this function.

    Args:
        depth (int): Depth of caller context. 0 selects the direct caller,
            1 the caller's caller, and so on. Default value: 0.

    Returns:
        str: module name (the ``__name__`` global) of the selected caller frame.

    Raises:
        RuntimeError: if the interpreter exposes no Python stack frames
            (``inspect.currentframe()`` returned ``None``).
        ValueError: if ``depth`` walks past the outermost frame of the stack.
    """
    # the following logic is a little bit faster than inspect.stack() logic,
    # because it follows f_back links instead of building a record per frame
    frame = inspect.currentframe()
    if frame is None:
        raise RuntimeError("get_caller_name requires Python stack frame support")
    # skip this function's own frame: `frame` is now the direct caller
    frame = frame.f_back
    for _ in range(depth):
        if frame is None:
            break
        frame = frame.f_back
    if frame is None:
        raise ValueError("depth={} exceeds the current call stack".format(depth))
    return frame.f_globals["__name__"]
class StreamToLoguru:
    """
    File-like stream that forwards writes from selected modules to loguru.

    Text written by code whose top-level package name is listed in
    ``caller_names`` is logged through loguru, one record per line. All other
    text is passed straight through to ``sys.__stdout__``.
    """

    def __init__(self, level="INFO", caller_names=("apex", "pycocotools")):
        """
        Args:
            level (str): log level string of loguru used for redirected lines.
                Default value: "INFO".
            caller_names (tuple): top-level package names whose writes are redirected.
                Default value: ("apex", "pycocotools").
        """
        self.level = level
        # NOTE(review): not read anywhere in this class; kept only so existing
        # code that touches the attribute keeps working
        self.linebuf = ""
        self.caller_names = caller_names

    def write(self, buf):
        """
        Log ``buf`` via loguru if the writer belongs to a redirected package,
        otherwise pass it through to ``sys.__stdout__`` unchanged.

        Args:
            buf (str): text handed to ``sys.stdout.write`` / ``sys.stderr.write``.

        Returns:
            int: ``len(buf)``, the number of characters consumed, as
            ``TextIOBase.write`` is expected to return.
        """
        full_name = get_caller_name(depth=1)
        # keep only the top-level package, e.g. "apex.amp.handle" -> "apex"
        module_name = full_name.partition(".")[0]
        if module_name in self.caller_names:
            for line in buf.rstrip().splitlines():
                # use caller level log: depth=2 makes loguru attribute the record
                # to a frame further up the stack instead of to this method
                logger.opt(depth=2).log(self.level, line.rstrip())
        else:
            sys.__stdout__.write(buf)
        return len(buf)

    def flush(self):
        """Flush the pass-through stream so non-redirected text is not held in its buffer."""
        stream = sys.__stdout__
        # sys.__stdout__ can be None, e.g. for apps started with pythonw
        if stream is not None:
            stream.flush()

    def isatty(self):
        """Report whether the pass-through stream is an interactive terminal."""
        # some libraries probe sys.stdout.isatty() and fail if it is missing
        stream = sys.__stdout__
        return stream is not None and stream.isatty()

    def fileno(self):
        """Return the file descriptor of the pass-through stream (needed by tools such as pdb)."""
        return sys.__stdout__.fileno()
def redirect_sys_output(log_level="INFO"):
    """
    Point both ``sys.stderr`` and ``sys.stdout`` at one shared StreamToLoguru.

    Args:
        log_level (str): loguru level name passed to the StreamToLoguru
            instance. Default value: "INFO".
    """
    # a chained assignment binds its targets left to right:
    # sys.stderr is replaced first, then sys.stdout, with the same object
    sys.stderr = sys.stdout = StreamToLoguru(log_level)
def setup_logger(save_dir, distributed_rank=0, filename="log.txt", mode="a"):
    """setup logger for training and testing.

    Removes every existing loguru handler. On rank 0 only, it then adds an
    INFO-level stderr sink and a file sink at ``save_dir/filename``. Finally
    stdout/stderr are redirected through ``redirect_sys_output`` on every rank.

    Args:
        save_dir (str): location to save log file.
        distributed_rank (int): device rank when multi-gpu environment; only
            rank 0 gets log sinks. Default value: 0.
        filename (str): log save name. Default value: "log.txt".
        mode (str): log file write mode: "a" (append) or "o" (override, i.e.
            delete an existing log file first). Any other value behaves like
            "a". Default value: "a".

    Return:
        the loguru ``logger`` instance that was configured.
    """
    loguru_format = (
        "<green>{time:YYYY-MM-DD HH:mm:ss}</green> | "
        "<level>{level: <8}</level> | "
        "<cyan>{name}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>"
    )

    logger.remove()
    save_file = os.path.join(save_dir, filename)

    # only keep logger in rank0 process
    if distributed_rank == 0:
        if mode == "o":
            # Only rank 0 deletes the old log. If every rank did it, a slower rank
            # could unlink the file rank 0 had just started writing, or crash with
            # FileNotFoundError after another rank removed it first.
            try:
                os.remove(save_file)
            except FileNotFoundError:
                pass
        logger.add(
            sys.stderr,
            format=loguru_format,
            level="INFO",
            enqueue=True,
        )
        logger.add(save_file)

    # redirect stdout/stderr to loguru
    redirect_sys_output("INFO")
    return logger