# Copyright (c) Facebook, Inc. and its affiliates. # Copyright (c) Meta Platforms, Inc. All Rights Reserved import os import wandb from detectron2.utils import comm from detectron2.utils.events import EventWriter, get_event_storage def setup_wandb(cfg, args): if comm.is_main_process(): init_args = { k.lower(): v for k, v in cfg.WANDB.items() if isinstance(k, str) and k not in ["config", "name"] } # only include most related part to avoid too big table # TODO: add configurable params to select which part of `cfg` should be saved in config if "config_exclude_keys" in init_args: init_args["config"] = cfg init_args["config"]["cfg_file"] = args.config_file else: init_args["config"] = { "model": cfg.MODEL, "solver": cfg.SOLVER, "cfg_file": args.config_file, } if ("name" not in init_args) or (init_args["name"] is None): init_args["name"] = os.path.basename(args.config_file) wandb.init(**init_args) class BaseRule(object): def __call__(self, target): return target class IsIn(BaseRule): def __init__(self, keyword: str): self.keyword = keyword def __call__(self, target): return self.keyword in target class Prefix(BaseRule): def __init__(self, keyword: str): self.keyword = keyword def __call__(self, target): return "/".join([self.keyword, target]) class WandbWriter(EventWriter): """ Write all scalars to a tensorboard file. """ def __init__(self): """ Args: log_dir (str): the directory to save the output events kwargs: other arguments passed to `torch.utils.tensorboard.SummaryWriter(...)` """ self._last_write = -1 self._group_rules = [ (IsIn("/"), BaseRule()), (IsIn("loss"), Prefix("train")), ] def write(self): storage = get_event_storage() def _group_name(scalar_name): for (rule, op) in self._group_rules: if rule(scalar_name): return op(scalar_name) return scalar_name stats = { _group_name(name): scalars[0] for name, scalars in storage.latest().items() if scalars[1] > self._last_write } if len(stats) > 0: self._last_write = max([v[1] for k, v in storage.latest().items()]) # storage.put_{image,histogram} is only meant to be used by # tensorboard writer. So we access its internal fields directly from here. if len(storage._vis_data) >= 1: stats["image"] = [ wandb.Image(img, caption=img_name) for img_name, img, step_num in storage._vis_data ] # Storage stores all image data and rely on this writer to clear them. # As a result it assumes only one writer will use its image data. # An alternative design is to let storage store limited recent # data (e.g. only the most recent image) that all writers can access. # In that case a writer may not see all image data if its period is long. storage.clear_images() if len(storage._histograms) >= 1: def create_bar(tag, bucket_limits, bucket_counts, **kwargs): data = [ [label, val] for (label, val) in zip(bucket_limits, bucket_counts) ] table = wandb.Table(data=data, columns=["label", "value"]) return wandb.plot.bar(table, "label", "value", title=tag) stats["hist"] = [create_bar(**params) for params in storage._histograms] storage.clear_histograms() if len(stats) == 0: return wandb.log(stats, step=storage.iter) def close(self): wandb.finish()