Spaces:
Runtime error
Runtime error
File size: 4,257 Bytes
fc16538 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
# TRI-VIDAR - Copyright 2022 Toyota Research Institute. All rights reserved.
import argparse
import os
from functools import partial
from termcolor import colored
from vidar.utils.distributed import on_rank_0
def pcolor(string, color, on_color=None, attrs=None):
    """
    Wrap a string in terminal color codes using termcolor

    Parameters
    ----------
    string : String
        Text to colorize
    color : String
        Foreground color name
    on_color : String
        Background color name (None for no background)
    attrs : list[String]
        Extra text attributes (e.g. 'bold')

    Returns
    -------
    string: String
        The colorized string
    """
    styled = colored(string, color, on_color=on_color, attrs=attrs)
    return styled
@on_rank_0
def print_config(config):
    """
    Prints a model configuration as an indented, colored tree

    Every attribute of ``config`` is printed on its own line. Attributes that
    are themselves ``argparse.Namespace`` objects are printed as a colored
    section header and then expanded recursively with a deeper indent.

    Parameters
    ----------
    config : Config
        Model configuration (any object exposing its fields via ``__dict__``)
    """
    # Header style per nesting depth; deeper levels reuse the deepest style
    # instead of raising KeyError.
    header_colors = {
        0: ('red', ('bold', 'dark')),
        1: ('cyan', ('bold', 'dark')),
        2: ('green', ('bold', 'dark')),
        3: ('green', ('bold', 'dark')),
    }
    max_level = max(header_colors)
    line_colors = ('blue', ())

    def print_recursive(rec_args, pad=3, level=0):
        """Print all attributes of rec_args, recursing into nested Namespaces"""
        color, attrs = header_colors[min(level, max_level)]
        for key, val in vars(rec_args).items():
            if isinstance(val, argparse.Namespace):
                print(pcolor('{} {}:'.format('-' * pad, key),
                             color=color,
                             attrs=attrs))
                print_recursive(val, pad + 2, level + 1)
            else:
                print('{}: {}'.format(pcolor('{} {}'.format('-' * pad, key),
                                             color=line_colors[0],
                                             attrs=line_colors[1]), val))

    # Surround the config dump with blank lines
    print()
    print_recursive(config)
    print()
def set_debug(debug):
    """
    Toggle terminal debug logging

    When debugging is disabled, NCCL_DEBUG is set to an empty string and
    WANDB_SILENT is set to 'true'. When it is enabled, the environment is
    left untouched.

    Parameters
    ----------
    debug : Bool
        True to keep debug logging, False to quiet it
    """
    if debug:
        return
    # Environment switches applied when debugging is turned off
    quiet_env = {
        'NCCL_DEBUG': '',
        'WANDB_SILENT': 'true',
    }
    for name, value in quiet_env.items():
        os.environ[name] = value
class AvgMeter:
    """
    Average meter for logging: running mean over the most recent values

    Parameters
    ----------
    n_max : Int
        Maximum number of recent values kept in the averaging window
    """
    def __init__(self, n_max=100):
        self.n_max = n_max
        self.values = []  # most recent values, oldest first

    def __call__(self, value):
        """Append new value (dropping the oldest beyond n_max) and return the average"""
        self.values.append(value)
        if len(self.values) > self.n_max:
            self.values.pop(0)
        return self.get()

    def get(self):
        """Get current average, or NaN when the meter holds no values"""
        if not self.values:
            # Averaging nothing is undefined; NaN instead of ZeroDivisionError
            # (e.g. right after reset(), or when n_max <= 0)
            return float('nan')
        return sum(self.values) / len(self.values)

    def reset(self):
        """Reset meter"""
        self.values.clear()

    def get_and_reset(self):
        """Get current average (NaN if empty) and reset"""
        average = self.get()
        self.reset()
        return average
|