import logging
from argparse import ArgumentParser
from datetime import datetime
from logging import Logger
from pathlib import Path
from typing import Any, Mapping, Optional, Union
import ignite.distributed as idist
import torch
import yaml
from ignite.contrib.engines import common
from ignite.engine import Engine
from ignite.engine.events import Events
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine
from ignite.handlers.early_stopping import EarlyStopping
from ignite.handlers.terminate_on_nan import TerminateOnNan
from ignite.handlers.time_limit import TimeLimit
from ignite.utils import setup_logger


def setup_parser(config_path="base_config.yaml"):
with open(config_path, "r") as f:
config = yaml.safe_load(f.read())
parser = ArgumentParser()
parser.add_argument("--config", default=None, type=str)
parser.add_argument("--backend", default=None, type=str)
for k, v in config.items():
if isinstance(v, bool):
parser.add_argument(f"--{k}", action="store_true")
else:
parser.add_argument(f"--{k}", default=v, type=type(v))
return parser
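

# Example usage of `setup_parser` (a minimal sketch; assumes a
# `base_config.yaml` with keys such as `lr` and `debug` next to this
# file; both key names are illustrative):
#
#     parser = setup_parser("base_config.yaml")
#     config = parser.parse_args(["--lr", "0.01"])
#     print(config.lr)
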
def log_metrics(engine: Engine, tag: str) -> None:
"""Log `engine.state.metrics` with given `engine` and `tag`.
Parameters
----------
engine
        instance of `Engine` whose metrics to log.
tag
a string to add at the start of output.
"""
metrics_format = "{0} [{1}/{2}]: {3}".format(
tag, engine.state.epoch, engine.state.iteration, engine.state.metrics
)
epoch_size = engine.state.epoch_length
local_iteration = engine.state.iteration - epoch_size * (engine.state.epoch - 1)
metrics_format = f"{tag} Epoch {engine.state.epoch} - [{local_iteration} / {epoch_size}] : {engine.state.metrics}"
engine.logger.info(metrics_format)
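

# Example: attach `log_metrics` so metrics are printed after every epoch
# (a sketch; `trainer` is a hypothetical `Engine` instance with metrics
# already attached):
#
#     trainer.add_event_handler(Events.EPOCH_COMPLETED, log_metrics, tag="train")
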
def resume_from(
to_load: Mapping,
checkpoint_fp: Union[str, Path],
logger: Logger,
strict: bool = True,
model_dir: Optional[str] = None,
) -> None:
"""Loads state dict from a checkpoint file to resume the training.
Parameters
----------
to_load
        a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
checkpoint_fp
path to the checkpoint file
logger
to log info about resuming from a checkpoint
strict
        whether to strictly enforce that the keys in `state_dict` match the keys
        returned by this module's `state_dict()` function. Default: True
model_dir
directory in which to save the object
"""
if isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https://"):
checkpoint = torch.hub.load_state_dict_from_url(
checkpoint_fp,
model_dir=model_dir,
map_location="cpu",
check_hash=True,
)
else:
if isinstance(checkpoint_fp, str):
checkpoint_fp = Path(checkpoint_fp)
if not checkpoint_fp.exists():
            raise FileNotFoundError(f"Checkpoint file '{checkpoint_fp}' does not exist.")
checkpoint = torch.load(checkpoint_fp, map_location="cpu")
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=strict)
logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)
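

# Example: resume model and optimizer state from a local checkpoint
# (a sketch; the objects and the path are hypothetical):
#
#     to_load = {"model": model, "optimizer": optimizer}
#     resume_from(to_load, "checkpoints/best_model.pt", logger)
#
# An `https://` URL also works: the checkpoint is then fetched with
# `torch.hub.load_state_dict_from_url` instead of `torch.load`.
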
def setup_output_dir(config: Any, rank: int) -> Path:
"""Create output folder."""
if rank == 0:
now = datetime.now().strftime("%Y%m%d-%H%M%S")
name = f"{now}-backend-{config.backend}-lr-{config.lr}"
path = Path(config.output_dir, name)
path.mkdir(parents=True, exist_ok=True)
config.output_dir = path.as_posix()
return Path(idist.broadcast(config.output_dir, src=0))
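

# Example: rank 0 creates the timestamped run directory and broadcasts its
# path, so every process receives the same `Path` (a sketch; `config` must
# provide `output_dir`, `backend`, and `lr`):
#
#     config.output_dir = setup_output_dir(config, idist.get_rank())
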
def setup_logging(config: Any) -> Logger:
"""Setup logger with `ignite.utils.setup_logger()`.
Parameters
----------
config
        config object. config has to contain `debug` and `output_dir` attributes.
Returns
-------
logger
an instance of `Logger`
"""
green = "\033[32m"
reset = "\033[0m"
logger = setup_logger(
name=f"{green}[ignite]{reset}",
level=logging.DEBUG if config.debug else logging.INFO,
format="%(name)s: %(message)s",
        filepath=Path(config.output_dir) / "training-info.log",
)
return logger
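

# Example wiring of the helpers above (a minimal sketch; assumes
# `base_config.yaml` defines `lr`, `debug`, and `output_dir`):
#
#     config = setup_parser().parse_args([])
#     config.output_dir = setup_output_dir(config, idist.get_rank())
#     logger = setup_logging(config)
#     logger.info("Output directory: %s", config.output_dir)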