Spaces:
Runtime error
Runtime error
File size: 1,761 Bytes
f7ac35e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 |
# Copyright (c) OpenMMLab. All rights reserved.
from ...dist_utils import master_only
from ..hook import HOOKS
from .base import LoggerHook
@HOOKS.register_module()
class DvcliveLoggerHook(LoggerHook):
    """Logger hook that forwards loggable tags to dvclive.

    The `dvclive`_ package has to be installed; it is imported while the
    hook is being constructed.

    Args:
        path (str): Directory where dvclive will write TSV log files.
        interval (int): Logging interval (every k iterations).
            Default: 10.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than `interval`.
            Default: True.
        reset_flag (bool): Whether to clear the output buffer after logging.
            Default: True.
        by_epoch (bool): Whether EpochBasedRunner is used.
            Default: True.

    .. _dvclive:
        https://dvc.org/doc/dvclive
    """

    def __init__(
            self,
            path,
            interval=10,
            ignore_last=True,
            reset_flag=True,
            by_epoch=True):
        """Set up the base logger hook, remember ``path``, load dvclive."""
        super(DvcliveLoggerHook, self).__init__(
            interval, ignore_last, reset_flag, by_epoch)
        self.path = path
        self.import_dvclive()

    def import_dvclive(self):
        """Import the ``dvclive`` module and keep it on ``self.dvclive``.

        Raises:
            ImportError: If ``dvclive`` cannot be imported; the message
                contains the pip command that installs it.
        """
        try:
            import dvclive
        except ImportError:
            raise ImportError(
                'Please run "pip install dvclive" to install dvclive')
        else:
            self.dvclive = dvclive

    @master_only
    def before_run(self, runner):
        """Initialise dvclive with the path given at construction time."""
        self.dvclive.init(self.path)

    @master_only
    def log(self, runner):
        """Send every loggable tag of ``runner`` to dvclive.

        Nothing is logged when ``get_loggable_tags`` yields an empty result.
        """
        tags = self.get_loggable_tags(runner)
        if not tags:
            return
        for tag_name, tag_value in tags.items():
            # The step is looked up once per tag, right before it is logged.
            self.dvclive.log(tag_name, tag_value, step=self.get_iter(runner))
|