""" Utilities for input-output loading/saving. """ from typing import Any, List import yaml import pickle import json class PrettySafeLoader(yaml.SafeLoader): """Custom loader for reading YAML files""" def construct_python_tuple(self, node): return tuple(self.construct_sequence(node)) PrettySafeLoader.add_constructor( u'tag:yaml.org,2002:python/tuple', PrettySafeLoader.construct_python_tuple ) def load_yml(path: str, loader_type: str = 'default'): """Read params from a yml file. Args: path (str): path to the .yml file loader_type (str, optional): type of loader used to load yml files. Defaults to 'default'. Returns: Any: object (typically dict) loaded from .yml file """ assert loader_type in ['default', 'safe'] loader = yaml.Loader if (loader_type == "default") else PrettySafeLoader with open(path, 'r') as f: data = yaml.load(f, Loader=loader) return data def save_yml(data: dict, path: str): """Save params in the given yml file path. Args: data (dict): data object to save path (str): path to .yml file to be saved """ with open(path, 'w') as f: yaml.dump(data, f, default_flow_style=False) def load_pkl(path: str, encoding: str = "ascii") -> Any: """Loads a .pkl file. Args: path (str): path to the .pkl file encoding (str, optional): encoding to use for loading. Defaults to "ascii". Returns: Any: unpickled object """ return pickle.load(open(path, "rb"), encoding=encoding) def save_pkl(data: Any, path: str) -> None: """Saves given object into .pkl file Args: data (Any): object to be saved path (str): path to the location to be saved at """ with open(path, 'wb') as f: pickle.dump(data, f) def load_json(path: str) -> dict: """Helper to load json file""" with open(path, 'rb') as f: data = json.load(f) return data def save_json(data: dict, path: str): """Helper to save `dict` as .json file.""" with open(path, 'w') as f: json.dump(data, f) def load_txt(path: str) -> List: """Loads lines of a .txt file. Args: path (str): path to the .txt file Returns: List: lines of .txt file """ with open(path) as f: lines = f.read().splitlines() return lines def save_txt(data: dict, path: str): """Writes data (lines) to a txt file. Args: data (dict): List of strings path (str): path to .txt file """ assert isinstance(data, list) lines = "\n".join(data) with open(path, "w") as f: f.write(str(lines))