File size: 4,257 Bytes
fc16538
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
# TRI-VIDAR - Copyright 2022 Toyota Research Institute.  All rights reserved.

import argparse
import os
from collections import deque
from functools import partial

from termcolor import colored

from vidar.utils.distributed import on_rank_0


def pcolor(string, color, on_color=None, attrs=None):
    """
    Wrap a string in terminal color/attribute escape codes via termcolor

    Parameters
    ----------
    string : String
        Text to colorize
    color : String
        Foreground color name
    on_color : String
        Optional background color name
    attrs : list[String]
        Optional text attributes (e.g. 'bold', 'dark')

    Returns
    -------
    string: String
        The input text wrapped in the requested escape codes
    """
    colored_text = colored(string, color, on_color=on_color, attrs=attrs)
    return colored_text


@on_rank_0
def print_config(config):
    """
    Prints a model configuration as a colored, indented tree

    Every attribute of ``config`` is printed. Nested ``argparse.Namespace``
    values are printed as colored section headers and expanded recursively,
    with the dash indent growing by 2 per level. Any other value is printed
    as a ``key: value`` line. The tree is framed by blank lines.

    Parameters
    ----------
    config : Config
        Model configuration (its attributes are read via ``vars``)
    """
    # Header color and attributes per nesting depth.
    # Deeper levels reuse the last entry instead of raising KeyError.
    header_colors = {
        0: ('red', ('bold', 'dark')),
        1: ('cyan', ('bold', 'dark')),
        2: ('green', ('bold', 'dark')),
        3: ('green', ('bold', 'dark')),
    }
    max_level = max(header_colors)
    line_colors = ('blue', ())

    def print_recursive(rec_args, pad=3, level=0):
        """Print all attributes of rec_args, recursing into nested namespaces"""
        header_color, header_attrs = header_colors[min(level, max_level)]
        for key, val in vars(rec_args).items():
            if isinstance(val, argparse.Namespace):
                print(pcolor('{} {}:'.format('-' * pad, key),
                             color=header_color,
                             attrs=header_attrs))
                print_recursive(val, pad + 2, level + 1)
            else:
                print('{}: {}'.format(pcolor('{} {}'.format('-' * pad, key),
                                             color=line_colors[0],
                                             attrs=line_colors[1]), val))

    print()
    print_recursive(config)
    print()


def set_debug(debug):
    """
    Toggle debug terminal logging through environment variables

    When ``debug`` is falsy, ``NCCL_DEBUG`` is cleared and ``WANDB_SILENT``
    is set to ``'true'``. When truthy, the environment is left untouched.

    Parameters
    ----------
    debug : Bool
        Debugging flag (True keeps logging enabled)
    """
    if debug:
        return
    # Quiet NCCL / wandb output when debugging is off
    os.environ.update({'NCCL_DEBUG': '', 'WANDB_SILENT': 'true'})


class AvgMeter:
    """
    Average meter for logging: running mean over the most recent values

    Parameters
    ----------
    n_max : int or None
        Maximum number of recent values kept in the window
        (None keeps every value)
    """
    def __init__(self, n_max=100):
        self.n_max = n_max
        # A bounded deque evicts its oldest entry in O(1) once full,
        # instead of the O(n) list.pop(0).
        # Note: the window size is fixed here, at construction.
        self.values = deque(maxlen=n_max)

    def __call__(self, value):
        """Append new value and return the current average"""
        self.values.append(value)
        return self.get()

    def get(self):
        """
        Get current average

        Raises
        ------
        ZeroDivisionError
            If the meter currently holds no values
        """
        return sum(self.values) / len(self.values)

    def reset(self):
        """Reset meter, discarding all stored values"""
        self.values.clear()

    def get_and_reset(self):
        """Get current average and reset"""
        average = self.get()
        self.reset()
        return average