Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
from ...dist_utils import master_only | |
from ..hook import HOOKS | |
from .base import LoggerHook | |
class NeptuneLoggerHook(LoggerHook): | |
"""Class to log metrics to NeptuneAI. | |
It requires `neptune-client` to be installed. | |
Args: | |
init_kwargs (dict): a dict contains the initialization keys as below: | |
- project (str): Name of a project in a form of | |
namespace/project_name. If None, the value of | |
NEPTUNE_PROJECT environment variable will be taken. | |
- api_token (str): User’s API token. | |
If None, the value of NEPTUNE_API_TOKEN environment | |
variable will be taken. Note: It is strongly recommended | |
to use NEPTUNE_API_TOKEN environment variable rather than | |
placing your API token in plain text in your source code. | |
- name (str, optional, default is 'Untitled'): Editable name of | |
the run. Name is displayed in the run's Details and in | |
Runs table as a column. | |
Check https://docs.neptune.ai/api-reference/neptune#init for | |
more init arguments. | |
interval (int): Logging interval (every k iterations). | |
ignore_last (bool): Ignore the log of last iterations in each epoch | |
if less than `interval`. | |
reset_flag (bool): Whether to clear the output buffer after logging | |
by_epoch (bool): Whether EpochBasedRunner is used. | |
.. _NeptuneAI: | |
https://docs.neptune.ai/you-should-know/logging-metadata | |
""" | |
def __init__(self, | |
init_kwargs=None, | |
interval=10, | |
ignore_last=True, | |
reset_flag=True, | |
with_step=True, | |
by_epoch=True): | |
super(NeptuneLoggerHook, self).__init__(interval, ignore_last, | |
reset_flag, by_epoch) | |
self.import_neptune() | |
self.init_kwargs = init_kwargs | |
self.with_step = with_step | |
def import_neptune(self): | |
try: | |
import neptune.new as neptune | |
except ImportError: | |
raise ImportError( | |
'Please run "pip install neptune-client" to install neptune') | |
self.neptune = neptune | |
self.run = None | |
def before_run(self, runner): | |
if self.init_kwargs: | |
self.run = self.neptune.init(**self.init_kwargs) | |
else: | |
self.run = self.neptune.init() | |
def log(self, runner): | |
tags = self.get_loggable_tags(runner) | |
if tags: | |
for tag_name, tag_value in tags.items(): | |
if self.with_step: | |
self.run[tag_name].log( | |
tag_value, step=self.get_iter(runner)) | |
else: | |
tags['global_step'] = self.get_iter(runner) | |
self.run[tag_name].log(tags) | |
def after_run(self, runner): | |
self.run.stop() | |