# ArneBinder's picture
# https://github.com/ArneBinder/pie-document-level/pull/312
# 3133b5e verified
# raw
# history blame
# 6.49 kB
import functools
import json
import os
import sys
import time
import warnings
from pathlib import Path
from typing import Any, Callable, Dict, Optional, Union

from omegaconf import DictConfig
from pytorch_lightning.utilities import rank_zero_only

from src.utils.logging_utils import close_loggers, get_pylogger
from src.utils.rich_utils import enforce_tags, print_config_tree
# Module-level logger (from `get_pylogger`) used by the helpers in this module
# (task_wrapper, extras, replace_sys_args_with_values_from_files).
log = get_pylogger(__name__)
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that wraps the task function in extra utilities.

    Makes multirun more resistant to failure.

    Utilities:
    - Calling the `utils.extras()` before the task is started
    - Calling the `utils.close_loggers()` after the task is finished
    - Logging the exception if occurs
    - Logging the task total execution time
    - Logging the output dir

    Args:
        task_func: the task to wrap; it is called as ``task_func(cfg=cfg)``.

    Returns:
        A wrapper that takes a single ``cfg`` argument and returns whatever ``task_func``
        returns. Via ``functools.wraps`` it keeps the name and docstring of ``task_func``.
    """

    @functools.wraps(task_func)
    def wrap(cfg: DictConfig):
        # apply extra utilities (warnings filter, tag enforcement, config printing)
        extras(cfg)

        # perf_counter is monotonic, so the measured duration is unaffected by clock changes
        start_time = time.perf_counter()
        try:
            task_result = task_func(cfg=cfg)
        except Exception:
            log.exception("")  # save exception to `.log` file
            raise  # bare raise re-raises the active exception with its original traceback
        finally:
            path = Path(cfg.paths.output_dir, "exec_time.log")
            content = f"'{cfg.pipeline_type}' execution time: {time.perf_counter() - start_time} (s)"
            save_file(path, content)  # save task execution time (even if exception occurs)
            close_loggers()  # close loggers (even if exception occurs so multirun won't fail)

        log.info(f"Output dir: {cfg.paths.output_dir}")
        return task_result

    return wrap
def extras(cfg: DictConfig) -> None:
    """Apply optional, config-driven utilities before a task starts.

    Each utility is switched on by a truthy flag under ``cfg.extras``, and the utilities run
    in this order:
    - ``ignore_warnings``: silence all python warnings
    - ``enforce_tags``: prompt the user for tags on the command line if none are configured
    - ``print_config``: pretty-print the resolved config tree with Rich

    If ``cfg.extras`` is missing or empty, only a warning is logged.
    """
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    # (flag under cfg.extras, message logged when enabled, utility to run)
    utilities = (
        (
            "ignore_warnings",
            "Disabling python warnings! <cfg.extras.ignore_warnings=True>",
            lambda: warnings.filterwarnings("ignore"),
        ),
        (
            "enforce_tags",
            "Enforcing tags! <cfg.extras.enforce_tags=True>",
            lambda: enforce_tags(cfg, save_to_file=True),
        ),
        (
            "print_config",
            "Printing config tree with Rich! <cfg.extras.print_config=True>",
            lambda: print_config_tree(cfg, resolve=True, save_to_file=True),
        ),
    )
    for flag_name, message, run_utility in utilities:
        # look the flag up on cfg.extras each time, exactly like separate if-statements would
        if cfg.extras.get(flag_name):
            log.info(message)
            run_utility()
@rank_zero_only
def save_file(path: Union[str, os.PathLike], content: str) -> None:
    """Write ``content`` to the text file at ``path``, overwriting any existing file.

    Save file in rank zero mode: with the ``rank_zero_only`` decorator, only one process
    writes the file in a multi-GPU setup.

    Args:
        path: destination file; plain strings and path-like objects (e.g. ``pathlib.Path``)
            are both accepted.
        content: text to write, encoded as UTF-8.
    """
    # "w" rather than "w+": the file is only written, never read back
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)
def load_value_from_file(path: str, split_path_key: str = ":", split_key_parts: str = "/") -> Any:
    """Load a value from a file.

    The path can point to elements within the file (see ``split_path_key``), and those
    elements can be nested (see ``split_key_parts``). For now, only .json files are supported.

    Example:
        ``"results.json:model/output_dir/0"`` loads ``results.json`` and returns
        ``data["model"]["output_dir"][0]``.

    Args:
        path: path to the file, optionally followed by ``split_path_key`` and a key path into
            the loaded data.
        split_path_key: split the path on the first occurrence of this value to get the path
            to the file and the key within the file.
        split_key_parts: the value to split the key on to get the nested keys. When the value
            reached so far is a list, the key part is interpreted as an integer index.

    Returns:
        The whole parsed JSON content if no key is given. Otherwise, the nested value the key
        points to (any JSON type: dict, list, str, number, bool or None).

    Raises:
        ValueError: if the file is not a .json file, or if a key part used on a list is not
            an integer.
        KeyError / IndexError: if a key or index does not exist in the data.
    """
    file_path, *key_path = path.split(split_path_key, maxsplit=1)
    file_extension = os.path.splitext(file_path)[1]
    if file_extension != ".json":
        raise ValueError(f"Expected .json file, got {file_extension}")
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's default encoding
    with open(file_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if not key_path:
        return data
    for key in key_path[0].split(split_key_parts):
        if isinstance(data, list):
            # key parts are strings; a list can only be indexed by an integer
            try:
                index = int(key)
            except ValueError as err:
                raise ValueError(
                    f"Expected an integer index to access a list, got {key!r} (in {path!r})"
                ) from err
            data = data[index]
        else:
            data = data[key]
    return data
def replace_sys_args_with_values_from_files(
    load_prefix: str = "LOAD_ARG:",
    load_multi_prefix: str = "LOAD_MULTI_ARG:",
    **load_value_from_file_kwargs,
) -> None:
    """Replaces arguments in sys.argv with values loaded from files.

    Each argument containing ``load_prefix`` (checked first) or ``load_multi_prefix`` is split
    at the first occurrence of that prefix:
    - The part after the prefix is a path understood by ``load_value_from_file``.
    - The loaded value is serialized as compact JSON with all double quotes removed.
    - The result is appended to the part before the prefix.

    For ``load_multi_prefix`` the value must be a list, and its outer brackets are dropped,
    which yields comma-separated values. All other arguments are kept unchanged.

    Examples:
        # job_return_value.json contains {"a": 1, "b": 2}
        python train.py LOAD_ARG:job_return_value.json
        # this will pass "{a:1,b:2}" as the first argument to train.py

        # job_return_value.json contains [1, 2, 3]
        python train.py LOAD_MULTI_ARG:job_return_value.json
        # this will pass "1,2,3" as the first argument to train.py

        # job_return_value.json contains {"model": {"output_dir": ["path1", "path2"], "f1": [0.7, 0.6]}}
        python train.py load_model=LOAD_MULTI_ARG:job_return_value.json:model/output_dir
        # this will pass "load_model=path1,path2" to train.py (e.g. for a multirun);
        # with LOAD_ARG instead, "load_model=[path1,path2]" is passed

    NOTE: since quotes are removed, string values containing characters that are special to
    the argument parser (e.g. commas or brackets) may not survive as intended.

    Args:
        load_prefix: the prefix to use for loading a single value from a file
        load_multi_prefix: the prefix to use for loading a list of values from a file
        **load_value_from_file_kwargs: additional kwargs to pass to load_value_from_file

    Raises:
        ValueError: if an argument using ``load_multi_prefix`` does not resolve to a list.
    """
    updated_args = []
    for arg in sys.argv[1:]:
        if load_prefix in arg:
            prefix, is_multirun_arg = load_prefix, False
        elif load_multi_prefix in arg:
            prefix, is_multirun_arg = load_multi_prefix, True
        else:
            updated_args.append(arg)
            continue

        # the prefix occurs in `arg`, so splitting once always yields exactly two parts
        arg_head, value_path = arg.split(prefix, maxsplit=1)
        log.warning(f'Replacing argument value for "{arg_head}" with content from {value_path}')
        json_value = load_value_from_file(value_path, **load_value_from_file_kwargs)
        if is_multirun_arg and not isinstance(json_value, list):
            raise ValueError(
                f"Expected list for multirun argument, got {type(json_value)}. If you just want "
                f"to set a single value, use {load_prefix} instead of {load_multi_prefix}."
            )

        # Compact JSON: the explicit separators emit no whitespace between items or after
        # keys, so (unlike stripping every space afterwards) spaces inside string values are
        # preserved. ensure_ascii=False keeps non-ASCII characters instead of \uXXXX escapes.
        json_value_str = json.dumps(json_value, separators=(",", ":"), ensure_ascii=False)
        # remove quotes
        json_value_str = json_value_str.replace('"', "")
        if is_multirun_arg:
            # remove the outer brackets of the list -> comma-separated values
            json_value_str = json_value_str[1:-1]
        updated_args.append(f"{arg_head}{json_value_str}")

    # Set sys.argv to the updated arguments (keeping the program name)
    sys.argv = [sys.argv[0]] + updated_args