File size: 2,998 Bytes
5e88f62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import functools
import logging
from pathlib import Path
import shutil
import glob

from datetime import datetime, timedelta

_LOG_DICT = {}

@functools.lru_cache(None)  # always the same :)
def get_datestring_for_the_run():
    return datetime.now().strftime("%Y%m%d_%H%M%S")

def _make_key(msg, args, kwargs):
    args_str = ', '.join([str(arg) for arg in args])
    kwargs_str = ', '.join([f'{str(k)}={str(v)}' for k, v in kwargs.items()])
    r = [msg]
    if args_str or kwargs_str:
        r.append(' % (')
    r.append(args_str)
    if args_str:
        r.append(', ')
    r.append(kwargs_str)
    if args_str or kwargs_str:
        r.append(')')
    # MyMessage % (arg1, arg2, kw1=v1m, kw2=v2m)
    return ''.join(r)


def debug_once(msg, *args, logger=None, **kwargs):
    key = _make_key(msg, args, kwargs)

    lvl = logging.DEBUG
    t = datetime.now()
    should_log = True

    if key in _LOG_DICT:
        plvl, pt = _LOG_DICT[key]
        # Do not overwrite
        if plvl > lvl:
            t = pt
            lvl = plvl
        should_log = False

    _LOG_DICT[key] = (lvl, t)
    if should_log:
        logger.debug(msg, *args, **kwargs)


def info_once(msg, *args, logger=None, **kwargs):
    key = _make_key(msg, args, kwargs)

    lvl = logging.INFO
    t = datetime.now()
    should_log = True

    if key in _LOG_DICT:
        plvl, pt = _LOG_DICT[key]
        should_log = plvl <= lvl and t - pt > timedelta(minutes=5)
        lvl = max(lvl, plvl)

    _LOG_DICT[key] = (lvl, t)
    if should_log:
        logger.info(msg, *args, **kwargs)


def getLogger(name):
    if name != 'gwm' and not name.startswith('gwm.'):
        name = 'gwm.' + name
    logger = logging.getLogger(name)
    logger.info_once = functools.partial(info_once, logger=logger)
    logger.debug_once = functools.partial(info_once, logger=logger)
    return logger


def checkpoint_code(log_path):
    code_path = Path(log_path) / 'code'
    if code_path.exists():
        code_path = code_path.with_name(f'code_{get_datestring_for_the_run()}')
    code_path.mkdir(parents=True, exist_ok=True)
    for file in glob.glob('*.py'):
        shutil.copy(file, code_path)
    shutil.copytree('datasets', code_path / 'datasets', ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))
    shutil.copytree('losses', code_path / 'losses', ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))
    shutil.copytree('utils', code_path / 'utils', ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))
    shutil.copytree('mask_former', code_path / 'mask_former', ignore=shutil.ignore_patterns('*.pyc', '__pycache__'))


class log_level:
    def __init__(self, logger, lvl=logging.INFO):
        self.logger = logging.getLogger(logger)
        self.lvl = lvl
        self.current_lvl = self.logger.level

    def __enter__(self):
        self.current_lvl = self.logger.level
        self.logger.setLevel(self.lvl)

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.logger.setLevel(self.current_lvl)