# ztrain/io.py
# Copyright (c) 2024 Praxis Maldevide - cc-by-nc-4.0 granted

import os
from glob import glob


def flatten_index(
    model_paths: list[str], allow_list: list[str]
) -> tuple[dict[str, int], list[str], list[str]]:
    """Filter model paths by basename and classify each kept path.

    Paths are processed in sorted order. A path is kept only when its
    basename (``os.path.basename``) appears in ``allow_list``.

    Args:
        model_paths: Candidate paths. Their input order does not matter
            because they are sorted before processing.
        allow_list: Basenames to keep.

    Returns:
        A tuple ``(index, flat, subtype)``:

        - ``flat`` holds the kept basenames in sorted-path order.
        - ``subtype`` is parallel to ``flat``. Each entry is ``'base'``,
          ``'instruct'`` or ``'other'``.
        - ``index`` maps each kept basename to its position in ``flat``.

        If two kept paths share a basename, both appear in ``flat`` and
        ``index`` holds the position of the later one.
    """
    # Build the lookup set once: O(1) membership instead of a list scan
    # for every path.
    allowed = set(allow_list)
    index: dict[str, int] = {}
    flat: list[str] = []
    subtype: list[str] = []
    for g in sorted(model_paths):
        name = os.path.basename(g)
        if name not in allowed:
            continue
        index[name] = len(flat)
        flat.append(name)
        # NOTE(review): the substring test runs on the whole path, not only
        # the basename. A parent directory, or the search root itself, that
        # contains "base"/"instruct" therefore decides the subtype. "base"
        # is checked first, so it wins when both substrings occur.
        if 'base' in g:
            subtype.append('base')
        elif 'instruct' in g:
            subtype.append('instruct')
        else:
            subtype.append('other')
    return index, flat, subtype


def list_for_path(
    path: str,
    include_folders: list[str],
    search: str = "/**/*",
) -> tuple[list[str], list[str], list[str], dict[str, int]]:
    """Glob ``path + search`` and index the matches named in ``include_folders``.

    ``glob`` is called without ``recursive=True``, so ``**`` matches exactly
    one path component, just like ``*``. The default ``"/**/*"`` therefore
    lists entries exactly two levels below ``path``.

    Args:
        path: Root directory the ``search`` pattern is appended to.
        include_folders: Basenames to keep (passed to ``flatten_index``).
        search: Glob pattern suffix appended verbatim to ``path``.

    Returns:
        ``(model_names, subtypes, model_list, group_idx)``:

        - ``model_names``, ``subtypes`` and ``group_idx`` are the ``flat``,
          ``subtype`` and ``index`` results of ``flatten_index``.
        - ``model_list`` is every glob match, sorted, including the
          matches that were filtered out.
    """
    model_list = sorted(glob(path + search))
    group_idx, model_names, subtypes = flatten_index(model_list, include_folders)
    return model_names, subtypes, model_list, group_idx