Spaces:
Sleeping
Sleeping
import logging | |
from argparse import ArgumentParser | |
from datetime import datetime | |
from logging import Logger | |
from pathlib import Path | |
from typing import Any, Mapping, Optional, Union | |
import ignite.distributed as idist | |
import torch | |
import yaml | |
from ignite.contrib.engines import common | |
from ignite.engine import Engine | |
from ignite.engine.events import Events | |
from ignite.handlers import Checkpoint, DiskSaver, global_step_from_engine | |
from ignite.handlers.early_stopping import EarlyStopping | |
from ignite.handlers.terminate_on_nan import TerminateOnNan | |
from ignite.handlers.time_limit import TimeLimit | |
from ignite.utils import setup_logger | |
def setup_parser(config_path="base_config.yaml"):
    """Build an `ArgumentParser` whose options mirror a YAML config file.

    Parameters
    ----------
    config_path
        path to the YAML file; every top-level key becomes a `--<key>` option
        whose default is the value found in the file.

    Returns
    -------
    parser
        an `ArgumentParser` with `--config`, `--backend` and one option per
        top-level key of the YAML file.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # An empty YAML document loads as `None`; treat it as "no options".
        config = yaml.safe_load(f) or {}

    parser = ArgumentParser()
    parser.add_argument("--config", default=None, type=str)
    parser.add_argument("--backend", default=None, type=str)

    for k, v in config.items():
        if isinstance(v, bool):
            # `store_true` defaults to False on its own, which would silently
            # drop a `true` coming from the YAML file; keep the file's value.
            parser.add_argument(f"--{k}", action="store_true", default=v)
        elif v is None:
            # `type(None)` cannot convert a command-line string, so a null
            # YAML value is exposed as a plain string option.
            parser.add_argument(f"--{k}", default=None, type=str)
        else:
            parser.add_argument(f"--{k}", default=v, type=type(v))
    return parser
def log_metrics(engine: Engine, tag: str) -> None:
    """Log `engine.state.metrics` with given `engine` and `tag`.

    The message reports the current epoch and the iteration's position inside
    that epoch. When `engine.state.epoch_length` is `None` the position cannot
    be computed, so the global iteration counter is reported instead.

    Parameters
    ----------
    engine
        instance of `Engine` which metrics to log.
    tag
        a string to add at the start of output.
    """
    state = engine.state
    epoch_size = state.epoch_length

    if epoch_size is None:
        # Without an epoch length there is no per-epoch offset to subtract.
        message = f"{tag} [{state.epoch}/{state.iteration}]: {state.metrics}"
    else:
        # Iterations are counted globally; remove the completed epochs.
        local_iteration = state.iteration - epoch_size * (state.epoch - 1)
        message = (
            f"{tag} Epoch {state.epoch} - "
            f"[{local_iteration} / {epoch_size}] : {state.metrics}"
        )
    engine.logger.info(message)
def resume_from(
    to_load: Mapping,
    checkpoint_fp: Union[str, Path],
    logger: Logger,
    strict: bool = True,
    model_dir: Optional[str] = None,
) -> None:
    """Restore objects from a checkpoint so that training can be resumed.

    Parameters
    ----------
    to_load
        a dictionary with objects, e.g. {"model": model, "optimizer": optimizer, ...}
    checkpoint_fp
        path to the checkpoint file, or a string starting with "https://"
        to fetch the checkpoint through `torch.hub`
    logger
        to log info about resuming from a checkpoint
    strict
        forwarded to `Checkpoint.load_objects`. Default: True
    model_dir
        forwarded to `torch.hub.load_state_dict_from_url`; only used when
        `checkpoint_fp` is an "https://" string

    Raises
    ------
    FileNotFoundError
        if `checkpoint_fp` is a local path that does not exist.
    """
    if not (isinstance(checkpoint_fp, str) and checkpoint_fp.startswith("https://")):
        # Local file: normalise to `Path` and fail early when it is missing.
        if isinstance(checkpoint_fp, str):
            checkpoint_fp = Path(checkpoint_fp)
        if not checkpoint_fp.exists():
            raise FileNotFoundError(f"Given {str(checkpoint_fp)} does not exist.")
        # NOTE(review): `torch.load` may unpickle arbitrary objects depending
        # on the torch version - only resume from trusted files.
        checkpoint = torch.load(checkpoint_fp, map_location="cpu")
    else:
        # Remote file: download via torch.hub with hash checking enabled.
        checkpoint = torch.hub.load_state_dict_from_url(
            checkpoint_fp,
            model_dir=model_dir,
            map_location="cpu",
            check_hash=True,
        )

    Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint, strict=strict)
    logger.info("Successfully resumed from a checkpoint: %s", checkpoint_fp)
def setup_output_dir(config: Any, rank: int) -> Path:
    """Create the run's output folder on rank 0 and share its path.

    Parameters
    ----------
    config
        config object; `output_dir`, `backend` and `lr` are read, and
        `output_dir` is overwritten on rank 0 with the created folder.
    rank
        rank of the calling process; only rank 0 creates the folder.

    Returns
    -------
    path
        the value of `config.output_dir` broadcast from rank 0, as a `Path`.
    """
    if rank == 0:
        timestamp = datetime.now().strftime("%Y%m%d-%H%M%S")
        run_name = f"{timestamp}-backend-{config.backend}-lr-{config.lr}"
        run_dir = Path(config.output_dir) / run_name
        run_dir.mkdir(parents=True, exist_ok=True)
        # Stored as a POSIX string so it can be broadcast below.
        config.output_dir = run_dir.as_posix()

    shared_dir = idist.broadcast(config.output_dir, src=0)
    return Path(shared_dir)
def setup_logging(config: Any) -> Logger:
    """Setup logger with `ignite.utils.setup_logger()`.

    Parameters
    ----------
    config
        config object. config has to contain `debug` and `output_dir` attribute.
        `debug` selects the DEBUG level (INFO otherwise); the log file
        "training-info.log" is placed inside `output_dir`.

    Returns
    -------
    logger
        an instance of `Logger`
    """
    # ANSI escape codes used to colour the logger name.
    green = "\033[32m"
    reset = "\033[0m"

    # `setup_output_dir` stores `output_dir` as a POSIX string, so coerce to
    # `Path` before joining; a `Path` value passes through unchanged.
    log_file = Path(config.output_dir) / "training-info.log"

    logger = setup_logger(
        name=f"{green}[ignite]{reset}",
        level=logging.DEBUG if config.debug else logging.INFO,
        format="%(name)s: %(message)s",
        filepath=log_file,
    )
    return logger