"""Manages mapping the dataset name to the database instance.""" import os import pathlib import threading from typing import Optional, Type, Union import yaml from pydantic import BaseModel from .config import DatasetConfig from .data.dataset import Dataset from .data.dataset_duckdb import get_config_filepath from .utils import get_datasets_dir _DEFAULT_DATASET_CLS: Type[Dataset] _CACHED_DATASETS: dict[str, Dataset] = {} _db_lock = threading.Lock() def get_dataset(namespace: str, dataset_name: str) -> Dataset: """Get the dataset instance.""" if not _DEFAULT_DATASET_CLS: raise ValueError('Default dataset class not set.') cache_key = f'{namespace}/{dataset_name}' # https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable inside_test = 'PYTEST_CURRENT_TEST' in os.environ with _db_lock: if cache_key not in _CACHED_DATASETS or inside_test: _CACHED_DATASETS[cache_key] = _DEFAULT_DATASET_CLS( namespace=namespace, dataset_name=dataset_name) return _CACHED_DATASETS[cache_key] def remove_dataset_from_cache(namespace: str, dataset_name: str) -> None: """Remove the dataset from the db manager cache.""" cache_key = f'{namespace}/{dataset_name}' with _db_lock: if cache_key in _CACHED_DATASETS: del _CACHED_DATASETS[cache_key] class DatasetInfo(BaseModel): """Information about a dataset.""" namespace: str dataset_name: str description: Optional[str] = None tags: list[str] = [] def list_datasets(base_dir: Union[str, pathlib.Path]) -> list[DatasetInfo]: """List the datasets in a data directory.""" datasets_path = get_datasets_dir(base_dir) # Skip if 'datasets' doesn't exist. if not os.path.isdir(datasets_path): return [] dataset_infos: list[DatasetInfo] = [] for namespace in os.listdir(datasets_path): dataset_dir = os.path.join(datasets_path, namespace) # Skip if namespace is not a directory. if not os.path.isdir(dataset_dir): continue if namespace.startswith('.'): continue for dataset_name in os.listdir(dataset_dir): # Skip if dataset_name is not a directory. dataset_path = os.path.join(dataset_dir, dataset_name) if not os.path.isdir(dataset_path): continue if dataset_name.startswith('.'): continue # Open the config file to read the tags. We avoid instantiating a dataset for now to reduce # the overhead of listing datasets. config_filepath = get_config_filepath(namespace, dataset_name) tags = [] if os.path.exists(config_filepath): with open(config_filepath) as f: config = DatasetConfig(**yaml.safe_load(f)) tags = config.tags dataset_infos.append(DatasetInfo(namespace=namespace, dataset_name=dataset_name, tags=tags)) return dataset_infos # TODO(nsthorat): Make this a registry once we have multiple dataset implementations. This breaks a # circular dependency. def set_default_dataset_cls(dataset_cls: Type[Dataset]) -> None: """Set the default dataset class.""" global _DEFAULT_DATASET_CLS _DEFAULT_DATASET_CLS = dataset_cls