# Source: chatlawv1 / trlx / tests/test_configs.py
# Uploaded by teachyourselfcoding in commit fa6856c ("Upload 245 files").
import os
from typing import List, Tuple, Union

from trlx.data.configs import TRLConfig
def _get_config_dirs(dir: str, config_dir_name: str = "configs") -> List[str]:
    """Return every directory named `config_dir_name` found anywhere below `dir`.

    Args:
        dir: Root directory to search recursively. A path that does not exist
            yields an empty list, because ``os.walk`` ignores errors unless an
            ``onerror`` callback is given.
        config_dir_name: Exact directory name to match. Defaults to "configs".

    Returns:
        ``os.path.join(parent, config_dir_name)`` for every match, sorted so the
        result does not depend on the filesystem's directory-listing order.
    """
    return sorted(
        os.path.join(root, name)
        for root, subdirs, _ in os.walk(dir)
        for name in subdirs
        if name == config_dir_name
    )
def _get_yaml_filepaths(dir: str, suffixes: Union[str, Tuple[str, ...]] = ".yml") -> List[str]:
    """Return the paths of the entries directly inside `dir` whose names end in `suffixes`.

    Only the immediate entries of `dir` are inspected (no recursion). An entry
    is selected purely by its name; it is not checked to be a regular file.

    Args:
        dir: Directory to list. ``FileNotFoundError`` from ``os.listdir``
            propagates if it does not exist.
        suffixes: A suffix, or a tuple of suffixes, passed straight to
            ``str.endswith``. Defaults to ".yml"; pass ``(".yml", ".yaml")`` to
            also collect ``.yaml`` files.

    Returns:
        ``os.path.join(dir, name)`` for each match, sorted by name so the order
        does not depend on the filesystem's listing order.
    """
    return [os.path.join(dir, name) for name in sorted(os.listdir(dir)) if name.endswith(suffixes)]
def test_repo_trl_configs():
    """Check that every default config shipped in the repository is valid.

    Collects the ``.yml`` files from the top-level ``configs`` directory and
    from every ``configs`` sub-directory found under ``examples``. Each file
    must load with ``TRLConfig.load_yaml`` and leave ``train.entity_name`` unset.

    The paths are relative, so they are resolved against the current working
    directory. Run the test from the directory that contains ``configs/`` and
    ``examples/`` (presumably the repository root).
    """
    config_dirs = ["configs", *_get_config_dirs("examples")]
    # Flatten the per-directory lists in one pass (sum(..., []) is quadratic).
    config_files = [path for config_dir in config_dirs for path in _get_yaml_filepaths(config_dir)]
    # Without this guard the test would pass vacuously if no configs were found.
    assert config_files, f"No config files found in {config_dirs}; run the tests from the repository root."

    for file in config_files:
        assert os.path.isfile(file), f"Config file {file} does not exist."
        assert file.endswith(".yml"), f"Config file {file} is not a yaml file."
        # Only the load is wrapped. Previously the entity-name assertion sat
        # inside the try, so its AssertionError was caught by `except Exception`
        # and re-reported as a misleading "Failed to load" error.
        try:
            config = TRLConfig.load_yaml(file)
        except Exception as e:
            raise AssertionError(f"Failed to load config file `{file}` with error `{e}`") from e
        assert (
            config.train.entity_name is None
        ), f"Unexpected entity name in config file `{file}`. Remove before pushing to repo."