# Repository page header (not part of the module): paul hilders,
# "Add new version of demo for IEAI course" (0241217), 2.69 kB.
"""
Utilities for input-output loading/saving.
"""
from typing import Any, List
import yaml
import pickle
import json
class PrettySafeLoader(yaml.SafeLoader):
    """`yaml.SafeLoader` subclass that provides a tuple constructor."""

    def construct_python_tuple(self, node):
        """Build a Python tuple from the items of the given sequence node."""
        items = self.construct_sequence(node)
        return tuple(items)
# Route nodes tagged as `python/tuple` to the tuple constructor defined on
# PrettySafeLoader, so such nodes load as tuples.
PrettySafeLoader.add_constructor(
    tag=u'tag:yaml.org,2002:python/tuple',
    constructor=PrettySafeLoader.construct_python_tuple,
)
def load_yml(path: str, loader_type: str = 'default'):
    """Read params from a yml file.

    Args:
        path (str): path to the .yml file
        loader_type (str, optional): type of loader used to load yml files.
            'default' selects ``yaml.Loader``; 'safe' selects
            ``PrettySafeLoader``. Defaults to 'default'.

    Returns:
        Any: object (typically dict) loaded from .yml file

    Raises:
        AssertionError: if ``loader_type`` is neither 'default' nor 'safe'.
    """
    # Raised explicitly instead of using `assert` so the check is not
    # stripped when Python runs with -O; the exception type is unchanged.
    if loader_type not in ('default', 'safe'):
        raise AssertionError(
            "loader_type must be 'default' or 'safe', got {!r}".format(loader_type)
        )

    # SECURITY NOTE: `yaml.Loader` is PyYAML's unrestricted loader and can
    # construct arbitrary Python objects from tagged input. Use
    # loader_type='safe' for files that do not come from a trusted source.
    loader = yaml.Loader if loader_type == 'default' else PrettySafeLoader

    with open(path, 'r') as f:
        return yaml.load(f, Loader=loader)
def save_yml(data: dict, path: str):
    """Dump the given object to a yml file in block (non-flow) style.

    Args:
        data (dict): data object to save
        path (str): path to the .yml file to be written
    """
    with open(path, 'w') as out:
        yaml.dump(data, stream=out, default_flow_style=False)
def load_pkl(path: str, encoding: str = "ascii") -> Any:
    """Loads a .pkl file.

    Note:
        Unpickling can execute arbitrary code; only load files that come
        from a trusted source.

    Args:
        path (str): path to the .pkl file
        encoding (str, optional): encoding passed to ``pickle.load``.
            Defaults to "ascii".

    Returns:
        Any: unpickled object
    """
    # Use a context manager so the file handle is closed even when
    # unpickling raises, instead of being left open.
    with open(path, "rb") as f:
        return pickle.load(f, encoding=encoding)
def save_pkl(data: Any, path: str) -> None:
    """Pickle the given object into the file at ``path``.

    Args:
        data (Any): object to be pickled
        path (str): destination of the .pkl file
    """
    with open(path, 'wb') as sink:
        pickle.dump(data, sink)
def load_json(path: str) -> dict:
    """Read the file at ``path`` in binary mode and parse it as JSON."""
    with open(path, 'rb') as source:
        raw = source.read()
    return json.loads(raw)
def save_json(data: dict, path: str):
    """Serialize ``data`` as JSON into the file at ``path``."""
    with open(path, 'w') as out:
        json.dump(data, fp=out)
def load_txt(path: str) -> List:
    """Read a .txt file and return its lines without line terminators.

    Args:
        path (str): path to the .txt file

    Returns:
        List: lines of the file, as produced by ``str.splitlines``
    """
    with open(path) as handle:
        content = handle.read()
    return content.splitlines()
def save_txt(data: List[str], path: str):
    """Writes data (lines) to a txt file.

    The strings are joined with a newline between them; no trailing
    newline is added.

    Args:
        data (List[str]): list of strings, one per output line
        path (str): path to .txt file

    Raises:
        AssertionError: if ``data`` is not a list.
    """
    # Raised explicitly instead of using `assert` so the check still runs
    # under `python -O`; the exception type is unchanged.
    if not isinstance(data, list):
        raise AssertionError(
            "data must be a list of strings, got {}".format(type(data).__name__)
        )

    # Join before opening so a non-string item fails without touching the file.
    text = "\n".join(data)
    with open(path, "w") as f:
        f.write(text)