Spaces:
Runtime error
Runtime error
File size: 3,205 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import logging
import warnings
from rich.console import Console
from rich.theme import Theme
from rich.pretty import install as pretty_install
from rich.traceback import install as traceback_install
from installer import log as installer_log, setup_logging
# Initialise logging via the project's installer module and reuse its logger here.
# NOTE(review): what setup_logging() configures is defined in installer, not visible in this file.
setup_logging()
log = installer_log
# Shared rich console used by every pretty-print / traceback call in this module.
# The theme override sets the traceback and inspect border styles to black.
console = Console(log_time=True, tab_size=4, log_time_format='%H:%M:%S-%f', soft_wrap=True, safe_box=True, theme=Theme({
"traceback.border": "black",
"traceback.border.syntax_error": "black",
"inspect.value.border": "black",
}))
# Install rich's pretty-printer and traceback handler at import time, bound to the shared console.
pretty_install(console=console)
traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False)
# Tasks already reported through display_once(); keys are the task values, the stored value (1) is only a marker.
already_displayed = {}
def install(suppress=()):
    """Install rich pretty-printing and traceback handling on the shared console.

    Also silences ``UserWarning`` warnings and calls ``logging.basicConfig``
    with level ``ERROR`` and a ``time | level | path | message`` format.

    Args:
        suppress: Iterable of modules or path strings forwarded to rich's
            traceback installer as its ``suppress`` argument. Defaults to an
            empty tuple (immutable, so the default cannot be shared and
            mutated across calls).
    """
    warnings.filterwarnings("ignore", category=UserWarning)
    pretty_install(console=console)
    traceback_install(console=console, extra_lines=1, width=console.width, word_wrap=False, indent_guides=False, suppress=suppress)
    logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
    # for handler in logging.getLogger().handlers:
    #     handler.setLevel(logging.INFO)
def print_error_explanation(message):
    """Log each line of *message* as a separate error record.

    Leading and trailing whitespace is stripped from the whole message before
    it is split on newline characters.
    """
    for entry in message.strip().split("\n"):
        log.error(entry)
def display(e: Exception, task, suppress=()):
    """Log an error headline and print a rich traceback to the shared console.

    The headline is ``"<task>: <ExceptionType>"``, with ``'error'`` used when
    *task* is falsy. The traceback itself is rendered by
    ``console.print_exception``, which is not given *e*; per rich's
    documentation it renders the exception currently being handled, so this
    should be called from inside an ``except`` block.

    Args:
        e: The exception whose type name appears in the log headline.
        task: Label describing what was being done when the error occurred.
        suppress: Iterable of modules or path strings forwarded to rich as
            ``suppress``. Defaults to an empty tuple (immutable default).
    """
    log.error(f"{task or 'error'}: {type(e).__name__}")
    console.print_exception(show_locals=False, max_frames=10, extra_lines=1, suppress=suppress, theme="ansi_dark", word_wrap=False, width=console.width)
def display_once(e: Exception, task):
    """Display an exception for *task* only the first time that task is seen.

    A task is recorded in the module-level ``already_displayed`` dict after
    ``display`` returns; later calls with the same task do nothing.
    """
    if task not in already_displayed:
        display(e, task)
        already_displayed[task] = 1
def run(code, task):
    """Call *code* with no arguments, reporting any ``Exception`` via ``display``.

    Exceptions derived from ``Exception`` are displayed under the *task*
    label and not re-raised; the return value of *code* is discarded.
    """
    try:
        code()
    except Exception as err:
        display(err, task)
def exception(suppress=()):
    """Print a rich traceback for the exception currently being handled.

    The output width is the console width capped at 200 columns.

    Args:
        suppress: Iterable of modules or path strings forwarded to rich as
            ``suppress``. Defaults to an empty tuple (immutable default).
    """
    console.print_exception(show_locals=False, max_frames=10, extra_lines=2, suppress=suppress, theme="ansi_dark", word_wrap=False, width=min([console.width, 200]))
def profile(profiler, msg: str):
    """Stop *profiler* and log a short, filtered summary of its statistics.

    The stats are sorted by cumulative time and printed (limited to 100
    entries) into an in-memory buffer. Lines mentioning interpreter, logging,
    profiler or rich internals, report boilerplate, and blank lines are then
    dropped, and the first five remaining lines are logged at debug level.

    Args:
        profiler: Profiler object exposing ``disable()`` and accepted by
            ``pstats.Stats`` (e.g. a ``cProfile.Profile``).
        msg: Label included in the log message.
    """
    profiler.disable()
    import io
    import pstats
    stream = io.StringIO() # pylint: disable=abstract-class-instantiated
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats(pstats.SortKey.CUMULATIVE)
    stats.print_stats(100)
    # Substrings identifying report lines that are noise for this summary.
    noise = (
        '<frozen',
        '{built-in',
        '/logging',
        'Ordered by',
        'List reduced',
        '_lsprof',
        '/profiler',
        'rich',
    )
    lines = [
        x for x in stream.getvalue().split('\n')
        if x.strip() != '' and not any(marker in x for marker in noise)
    ]
    # Slicing already clamps to the list length, so no explicit min() is needed.
    txt = '\n'.join(lines[:5])
    log.debug(f'Profile {msg}: {txt}')
def profile_torch(profiler, msg: str):
    """Stop *profiler* and log its key-averages table at debug level.

    The table is sorted by ``self_cpu_time_total`` and limited to 12 rows;
    rows containing ``/profiler`` or ``---`` are removed before logging.
    """
    profiler.stop()
    table = profiler.key_averages().table(sort_by="self_cpu_time_total", row_limit=12)
    kept = [row for row in table.split('\n') if not ('/profiler' in row or '---' in row)]
    txt = '\n'.join(kept)
    log.debug(f'Torch profile {msg}: \n{txt}')
|