Spaces:
Runtime error
Runtime error
import logging | |
import warnings | |
from rich.console import Console | |
from rich.theme import Theme | |
from rich.pretty import install as pretty_install | |
from rich.traceback import install as traceback_install | |
from installer import log as installer_log, setup_logging | |
# Initialise logging via the installer module and reuse its logger here
# (what setup_logging configures is defined in `installer`, not in this file).
setup_logging()
log = installer_log

# Shared rich console used for every traceback / pretty-print call in this module.
console = Console(
    log_time=True,
    tab_size=4,
    log_time_format='%H:%M:%S-%f',
    soft_wrap=True,
    safe_box=True,
    theme=Theme({
        "traceback.border": "black",
        "traceback.border.syntax_error": "black",
        "inspect.value.border": "black",
    }),
)

# Hook rich into the interpreter at import time: pretty-printing plus traceback rendering.
pretty_install(console=console)
traceback_install(
    console=console,
    extra_lines=1,
    width=console.width,
    word_wrap=False,
    indent_guides=False,
)

# Tasks whose exception has already been shown; read and written by display_once().
already_displayed = {}
def install(suppress=None):
    """Install rich pretty-printing and traceback handling on the shared console.

    Also ignores ``UserWarning`` warnings and calls ``logging.basicConfig``
    with level ``ERROR``.

    Args:
        suppress: Optional iterable forwarded to rich's traceback handler as
            its ``suppress`` argument. ``None`` (the default) is treated as
            an empty list, so no shared mutable default is used.
    """
    suppress = [] if suppress is None else suppress
    warnings.filterwarnings("ignore", category=UserWarning)
    pretty_install(console=console)
    traceback_install(
        console=console,
        extra_lines=1,
        width=console.width,
        word_wrap=False,
        indent_guides=False,
        suppress=suppress,
    )
    # NOTE: logging.basicConfig does nothing if the root logger already has handlers.
    logging.basicConfig(level=logging.ERROR, format='%(asctime)s | %(levelname)s | %(pathname)s | %(message)s')
    # for handler in logging.getLogger().handlers:
    #     handler.setLevel(logging.INFO)
def print_error_explanation(message):
    """Log a multi-line message at error level, one log record per line.

    Surrounding whitespace is stripped from the whole message first, then it
    is split on ``"\\n"``.
    """
    trimmed = message.strip()
    for row in trimmed.split("\n"):
        log.error(row)
def display(e: Exception, task, suppress=None):
    """Log a one-line error summary, then print a rich traceback.

    Args:
        e: The exception being reported; only its type name is used for the
            log line.
        task: Label for the log line; falls back to ``'error'`` when falsy.
        suppress: Optional iterable forwarded to ``console.print_exception``
            as ``suppress``. ``None`` (the default) is treated as an empty
            list, so no shared mutable default is used.

    NOTE(review): ``console.print_exception`` renders the exception currently
    being handled rather than ``e`` itself, so this is presumably meant to be
    called from inside an ``except`` block - confirm with callers.
    """
    suppress = [] if suppress is None else suppress
    log.error(f"{task or 'error'}: {type(e).__name__}")
    console.print_exception(
        show_locals=False,
        max_frames=10,
        extra_lines=1,
        suppress=suppress,
        theme="ansi_dark",
        word_wrap=False,
        width=console.width,
    )
def display_once(e: Exception, task):
    """Display an exception for ``task`` only the first time that task is seen.

    Tasks already recorded in ``already_displayed`` are skipped; a task is
    recorded only after ``display`` has returned.
    """
    if task not in already_displayed:
        display(e, task)
        already_displayed[task] = 1
def run(code, task):
    """Call ``code()`` and report any ``Exception`` through ``display``.

    The exception is not re-raised and the return value of ``code`` is
    discarded, so this function always returns ``None``.
    """
    try:
        code()
    except Exception as err:
        # Best-effort boundary: show the failure under the task label and carry on.
        display(err, task)
def exception(suppress=None):
    """Print a rich traceback on the shared console.

    The width is capped at 200 columns.

    Args:
        suppress: Optional iterable forwarded to ``console.print_exception``
            as ``suppress``. ``None`` (the default) is treated as an empty
            list, so no shared mutable default is used.
    """
    suppress = [] if suppress is None else suppress
    console.print_exception(
        show_locals=False,
        max_frames=10,
        extra_lines=2,
        suppress=suppress,
        theme="ansi_dark",
        word_wrap=False,
        width=min(console.width, 200),
    )
def profile(profiler, msg: str):
    """Stop ``profiler`` and log a short summary of its statistics at debug level.

    The stats are sorted by cumulative time and printed (limited to 100
    entries) into an in-memory buffer. Lines that are blank or contain any of
    the noise markers below are dropped, and the first five remaining lines
    are logged.

    Args:
        profiler: Profiler object accepted by ``pstats.Stats`` and exposing
            ``disable()`` (presumably a ``cProfile.Profile`` - confirm with callers).
        msg: Label included in the log message.
    """
    # Disable first so the imports and formatting below are not profiled.
    profiler.disable()
    import io
    import pstats

    stream = io.StringIO() # pylint: disable=abstract-class-instantiated
    stats = pstats.Stats(profiler, stream=stream)
    stats.sort_stats(pstats.SortKey.CUMULATIVE)
    stats.print_stats(100)
    # stats.print_title()
    # stats.print_call_heading(10, 'time')
    # stats.print_callees(10)
    # stats.print_callers(10)

    # Substrings marking report lines that are dropped from the summary.
    noise_markers = (
        '<frozen',
        '{built-in',
        '/logging',
        'Ordered by',
        'List reduced',
        '_lsprof',
        '/profiler',
        'rich',
    )
    lines = [
        line for line in stream.getvalue().split('\n')
        if line.strip() != '' and not any(marker in line for marker in noise_markers)
    ]
    # Slicing already clamps to the list length, so no explicit min() is needed.
    txt = '\n'.join(lines[:5])
    log.debug(f'Profile {msg}: {txt}')
def profile_torch(profiler, msg: str):
    """Stop ``profiler`` and log its key-averages table at debug level.

    The table is sorted by ``self_cpu_time_total`` and limited to 12 rows.
    Rows containing ``/profiler`` or ``---`` are removed before logging.

    Args:
        profiler: Object exposing ``stop()`` and ``key_averages().table(...)``
            (presumably a ``torch.profiler.profile`` instance - confirm with callers).
        msg: Label included in the log message.
    """
    profiler.stop()
    table = profiler.key_averages().table(sort_by="self_cpu_time_total", row_limit=12)
    kept_rows = [
        row for row in table.split('\n')
        if not ('/profiler' in row or '---' in row)
    ]
    txt = '\n'.join(kept_rows)
    log.debug(f'Torch profile {msg}: \n{txt}')