Spaces:
Runtime error
Runtime error
import os | |
from .base_logger import BaseLogger | |
class WandbLogger(BaseLogger):
    """Logger that sends metrics and model checkpoints to Weights & Biases.

    ``wandb`` is imported inside ``__init__``, so the package is only needed
    when this logger is actually constructed. Construction starts a wandb run,
    or reuses one that is already active (see ``run``).
    """

    def __init__(self,
                 project=None,
                 name=None,
                 id=None,
                 entity=None,
                 save_dir=None,
                 config=None,
                 **kwargs):
        """Store the run settings and start or attach to a wandb run.

        Args:
            project: passed to ``wandb.init`` as ``project``.
            name: passed to ``wandb.init`` as ``name``.
            id: passed to ``wandb.init`` as ``id``, together with
                ``resume="allow"``.
            entity: passed to ``wandb.init`` as ``entity``.
            save_dir: passed to ``wandb.init`` as ``dir``. ``log_model`` also
                reads checkpoint files from this directory.
            config: if truthy, merged into the run's config via
                ``run.config.update``.
            **kwargs: extra ``wandb.init`` arguments. They are applied last,
                so they override the defaults assembled here.

        Raises:
            ModuleNotFoundError: if the ``wandb`` package is not installed.
        """
        try:
            import wandb
        except ModuleNotFoundError as err:
            raise ModuleNotFoundError(
                "Please install wandb using `pip install wandb`"
            ) from err
        self.wandb = wandb

        self.project = project
        self.name = name
        self.id = id
        self.save_dir = save_dir
        self.config = config
        self.kwargs = kwargs
        self.entity = entity
        # Cached run handle; populated lazily by the ``run`` property.
        self._run = None
        self._wandb_init = dict(
            project=self.project,
            name=self.name,
            id=self.id,
            entity=self.entity,
            dir=self.save_dir,
            resume="allow",
        )
        self._wandb_init.update(kwargs)

        # Touch the property now so the run is created (or attached) during
        # construction rather than on the first log call.
        _ = self.run

        if self.config:
            self.run.config.update(self.config)

    @property
    def run(self):
        """Return the wandb run, creating or attaching to it on first access.

        If ``wandb.run`` is already set, that run is reused and an info
        message is logged. Otherwise a new run is started with
        ``wandb.init(**self._wandb_init)``. The result is cached in
        ``self._run``, so later accesses return the same object.
        """
        if self._run is None:
            if self.wandb.run is not None:
                # ``logger`` was never defined in this module; use the stdlib
                # logging module, imported locally to keep this self-contained.
                import logging
                logging.getLogger(__name__).info(
                    "There is a wandb run already in progress "
                    "and newly created instances of `WandbLogger` will reuse"
                    " this run. If this is not desired, call `wandb.finish()`"
                    " before instantiating `WandbLogger`."
                )
                self._run = self.wandb.run
            else:
                self._run = self.wandb.init(**self._wandb_init)
        return self._run

    def log_metrics(self, metrics, prefix=None, step=None):
        """Log a mapping of metrics to the run under ``<prefix>/`` keys.

        Args:
            metrics: mapping of metric name (str) to value.
            prefix: lower-cased and joined to each key with "/". A falsy
                prefix becomes "", so keys are logged as "/<name>".
            step: forwarded unchanged to ``run.log``.
        """
        prefix = (prefix or "").lower()
        updated_metrics = {prefix + "/" + k: v for k, v in metrics.items()}
        self.run.log(updated_metrics, step=step)

    def log_model(self, is_best, prefix, metadata=None):
        """Upload ``<save_dir>/<prefix>.pdparams`` as a wandb model artifact.

        The artifact is named ``model-<run id>`` and stores the file as
        ``model_ckpt.pdparams``. It is logged with ``prefix`` as an alias,
        plus ``"best"`` when ``is_best`` is truthy.

        Args:
            is_best: whether to add the ``"best"`` alias.
            prefix: checkpoint file stem inside ``save_dir``; also used as an
                alias.
            metadata: optional metadata attached to the artifact.

        Note:
            Requires ``save_dir`` to have been set. ``os.path.join`` raises
            ``TypeError`` when it is ``None``.
        """
        model_path = os.path.join(self.save_dir, prefix + '.pdparams')
        artifact = self.wandb.Artifact(
            'model-{}'.format(self.run.id), type='model', metadata=metadata)
        artifact.add_file(model_path, name="model_ckpt.pdparams")
        aliases = [prefix]
        if is_best:
            aliases.append("best")
        self.run.log_artifact(artifact, aliases=aliases)

    def close(self):
        """Finish the wandb run held by this logger."""
        self.run.finish()