# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# Code borrowed from TLC - https://www.internalfb.com/code/fbsource/fbcode/pytorch/tlc/torchtlc/loggers/tensorboard.py
import atexit
import functools
import logging
import sys
import uuid
from typing import Any, Dict, Optional, Union
from hydra.utils import instantiate
from iopath.common.file_io import g_pathmgr
from numpy import ndarray
from torch import Tensor
from torch.utils.tensorboard import SummaryWriter
from training.utils.train_utils import get_machine_local_and_dist_rank, makedir
Scalar = Union[Tensor, ndarray, int, float]
def make_tensorboard_logger(log_dir: str, **writer_kwargs: Any):
makedir(log_dir)
summary_writer_method = SummaryWriter
return TensorBoardLogger(
path=log_dir, summary_writer_method=summary_writer_method, **writer_kwargs
)
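
# Usage note (illustrative): `Logger` below builds its TensorBoard backend via
# `hydra.utils.instantiate`, so this factory is usually referenced from a
# Hydra/OmegaConf config node. The config keys and the module path below are
# assumptions about the surrounding project layout (they assume this file is
# importable as `training.utils.logger`), e.g.:
#
#   tensorboard_writer:
#     _target_: training.utils.logger.make_tensorboard_logger
#     log_dir: /path/to/experiment/tensorboard
#     should_log: true
#
# `should_log` is consumed by `Logger` below; any extra keyword arguments
# (e.g. `flush_secs`) are forwarded to `SummaryWriter` via `**writer_kwargs`.
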
class TensorBoardWriterWrapper:
"""
A wrapper around a SummaryWriter object.
"""
def __init__(
self,
path: str,
*args: Any,
        filename_suffix: Optional[str] = None,
summary_writer_method: Any = SummaryWriter,
**kwargs: Any,
) -> None:
"""Create a new TensorBoard logger.
On construction, the logger creates a new events file that logs
        will be written to. If the environment variable `RANK` is defined,
        the logger will only log on rank 0.
NOTE: If using the logger with distributed training:
- This logger can call collective operations
- Logs will be written on rank 0 only
        - The logger must be constructed synchronously *after* initializing the distributed process group.
Args:
path (str): path to write logs to
*args, **kwargs: Extra arguments to pass to SummaryWriter
"""
self._writer: Optional[SummaryWriter] = None
_, self._rank = get_machine_local_and_dist_rank()
self._path: str = path
if self._rank == 0:
logging.info(
f"TensorBoard SummaryWriter instantiated. Files will be stored in: {path}"
)
self._writer = summary_writer_method(
log_dir=path,
*args,
filename_suffix=filename_suffix or str(uuid.uuid4()),
**kwargs,
)
else:
logging.debug(
f"Not logging meters on this host because env RANK: {self._rank} != 0"
)
atexit.register(self.close)
@property
def writer(self) -> Optional[SummaryWriter]:
return self._writer
@property
def path(self) -> str:
return self._path
def flush(self) -> None:
"""Writes pending logs to disk."""
if not self._writer:
return
self._writer.flush()
def close(self) -> None:
"""Close writer, flushing pending logs to disk.
Logs cannot be written after `close` is called.
"""
if not self._writer:
return
self._writer.close()
self._writer = None
class TensorBoardLogger(TensorBoardWriterWrapper):
"""
A simple logger for TensorBoard.
"""
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
"""Add multiple scalar values to TensorBoard.
Args:
            payload (dict): dictionary of tag names and scalar values
            step (int, optional): step value to record
"""
if not self._writer:
return
for k, v in payload.items():
self.log(k, v, step)
def log(self, name: str, data: Scalar, step: int) -> None:
"""Add scalar data to TensorBoard.
Args:
name (string): tag name used to group scalars
data (float/int/Tensor): scalar data to log
step (int, optional): step value to record
"""
if not self._writer:
return
self._writer.add_scalar(name, data, global_step=step, new_style=True)
def log_hparams(
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
) -> None:
"""Add hyperparameter data to TensorBoard.
Args:
hparams (dict): dictionary of hyperparameter names and corresponding values
            meters (dict): dictionary of meter names and corresponding values
"""
if not self._writer:
return
self._writer.add_hparams(hparams, meters)
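
# Example usage (illustrative only; the tag names and values are made up):
#
#   tb_logger = make_tensorboard_logger("/tmp/tb_logs", flush_secs=30)
#   tb_logger.log("loss/train", 0.42, step=100)
#   tb_logger.log_dict({"lr": 1e-4, "grad_norm": 2.3}, step=100)
#   tb_logger.log_hparams({"batch_size": 32, "lr": 1e-4}, {"val/loss": 0.37})
#   tb_logger.close()
#
# On ranks other than 0 the underlying writer is None, so all of these calls
# are silent no-ops.
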
class Logger:
"""
    A logger class that can interface with multiple logging backends. For simplicity it currently supports TensorBoard only, but it can be extended with your own loggers.
"""
def __init__(self, logging_conf):
# allow turning off TensorBoard with "should_log: false" in config
tb_config = logging_conf.tensorboard_writer
tb_should_log = tb_config and tb_config.pop("should_log", True)
self.tb_logger = instantiate(tb_config) if tb_should_log else None
def log_dict(self, payload: Dict[str, Scalar], step: int) -> None:
if self.tb_logger:
self.tb_logger.log_dict(payload, step)
def log(self, name: str, data: Scalar, step: int) -> None:
if self.tb_logger:
self.tb_logger.log(name, data, step)
def log_hparams(
self, hparams: Dict[str, Scalar], meters: Dict[str, Scalar]
) -> None:
if self.tb_logger:
self.tb_logger.log_hparams(hparams, meters)
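
# Example (illustrative): building a `Logger` from an OmegaConf node. The
# `_target_` path is an assumption that this file is importable as
# `training.utils.logger`; adapt it and the paths to your own layout.
#
#   from omegaconf import OmegaConf
#
#   logging_conf = OmegaConf.create(
#       {
#           "tensorboard_writer": {
#               "_target_": "training.utils.logger.make_tensorboard_logger",
#               "log_dir": "/tmp/tb_logs",
#               "should_log": True,
#           }
#       }
#   )
#   logger = Logger(logging_conf)
#   logger.log("loss/train", 0.42, step=1)
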
# Cache the opened file object, so that different calls to `setup_logging`
# with the same filename can safely write to the same file.
@functools.lru_cache(maxsize=None)
def _cached_log_stream(filename):
    # Use a moderately small buffer so that log lines are flushed to the file
    # frequently.
    log_buffer_size = 10 * 1024  # 10 KiB, in bytes
    io = g_pathmgr.open(filename, mode="a", buffering=log_buffer_size)
atexit.register(io.close)
return io
def setup_logging(
name,
output_dir=None,
rank=0,
log_level_primary="INFO",
log_level_secondary="ERROR",
):
"""
    Set up the various logging streams: stdout and file handlers.
    File handlers are only set up on the master rank (rank 0).
"""
# get the filename if we want to log to the file as well
log_filename = None
if output_dir:
makedir(output_dir)
if rank == 0:
log_filename = f"{output_dir}/log.txt"
logger = logging.getLogger(name)
logger.setLevel(log_level_primary)
# create formatter
FORMAT = "%(levelname)s %(asctime)s %(filename)s:%(lineno)4d: %(message)s"
formatter = logging.Formatter(FORMAT)
# Cleanup any existing handlers
for h in logger.handlers:
logger.removeHandler(h)
logger.root.handlers = []
# setup the console handler
console_handler = logging.StreamHandler(sys.stdout)
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)
if rank == 0:
console_handler.setLevel(log_level_primary)
else:
console_handler.setLevel(log_level_secondary)
# we log to file as well if user wants
if log_filename and rank == 0:
file_handler = logging.StreamHandler(_cached_log_stream(log_filename))
file_handler.setLevel(log_level_primary)
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
logging.root = logger
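
# Example (illustrative): call once per process, after the distributed process
# group (if any) has been initialized, so that the rank is known:
#
#   rank = get_machine_local_and_dist_rank()[1]
#   setup_logging(__name__, output_dir="/tmp/experiment_logs", rank=rank)
#   logging.info("logging initialized")
#
# With the defaults above, rank 0 logs INFO and higher to stdout and to
# `<output_dir>/log.txt`; other ranks only emit ERROR and higher to stdout.
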
def shutdown_logging():
"""
    After training is done, ensure that all the logger streams are shut down.
"""
logging.info("Shutting down loggers...")
handlers = logging.root.handlers
for handler in handlers:
handler.close()
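

if __name__ == "__main__":
    # Minimal smoke test (illustrative; assumes a writable /tmp and that the
    # `tensorboard` package is installed).
    setup_logging(__name__, output_dir="/tmp/logger_smoke_test", rank=0)
    tb_logger = make_tensorboard_logger("/tmp/logger_smoke_test/tensorboard")
    for demo_step in range(3):
        # Write a few dummy scalars so an events file is produced.
        tb_logger.log("demo/constant", 1.0, demo_step)
    tb_logger.close()
    shutdown_logging()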