Spaces:
Runtime error
Runtime error
"""Manages mapping the dataset name to the database instance.""" | |
import os
import threading
from typing import Optional, Type

from .data.dataset import Dataset
# Dataset implementation instantiated by `get_dataset`. It is explicitly bound
# to `None` so that the "not set" check in `get_dataset` can run: a bare
# annotation without a value would leave the name undefined and reading it
# would raise `NameError` instead of the intended `ValueError`.
_DEFAULT_DATASET_CLS: Optional[Type[Dataset]] = None

# Dataset instances that have already been constructed, keyed by
# '<namespace>/<dataset_name>'.
_CACHED_DATASETS: dict[str, Dataset] = {}

# Held while reading or mutating `_CACHED_DATASETS`.
_db_lock = threading.Lock()
def get_dataset(namespace: str, dataset_name: str) -> Dataset:
    """Return the dataset instance for `namespace` / `dataset_name`.

    Instances are kept in `_CACHED_DATASETS` under the key
    '<namespace>/<dataset_name>'. A new instance is built with
    `_DEFAULT_DATASET_CLS(namespace=..., dataset_name=...)` when the key is not
    cached yet, and on every call while the `PYTEST_CURRENT_TEST` environment
    variable is present.

    Args:
        namespace: First component of the cache key; forwarded to the dataset class.
        dataset_name: Second component of the cache key; forwarded to the dataset class.

    Returns:
        The dataset instance stored in the cache for this key.

    Raises:
        ValueError: If `_DEFAULT_DATASET_CLS` is falsy.
    """
    if not _DEFAULT_DATASET_CLS:
        raise ValueError('Default dataset class not set.')

    key = f'{namespace}/{dataset_name}'

    # pytest exports this variable while a test is executing, see
    # https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
    running_under_pytest = 'PYTEST_CURRENT_TEST' in os.environ

    with _db_lock:
        must_build = running_under_pytest or key not in _CACHED_DATASETS
        if must_build:
            _CACHED_DATASETS[key] = _DEFAULT_DATASET_CLS(
                namespace=namespace, dataset_name=dataset_name)
        return _CACHED_DATASETS[key]
def remove_dataset_from_cache(namespace: str, dataset_name: str) -> None:
    """Drop the cached dataset instance for `namespace` / `dataset_name`, if any.

    Args:
        namespace: First component of the '<namespace>/<dataset_name>' cache key.
        dataset_name: Second component of the cache key.
    """
    key = f'{namespace}/{dataset_name}'
    with _db_lock:
        # `pop` with a default does nothing when the key was never cached.
        _CACHED_DATASETS.pop(key, None)
# TODO(nsthorat): Make this a registry once we have multiple dataset
# implementations. This breaks a circular dependency.
def set_default_dataset_cls(dataset_cls: Type[Dataset]) -> None:
    """Register `dataset_cls` as the module-wide default dataset class.

    Args:
        dataset_cls: Class stored in the module global `_DEFAULT_DATASET_CLS`,
            replacing whatever value was there before.
    """
    global _DEFAULT_DATASET_CLS
    _DEFAULT_DATASET_CLS = dataset_cls