# File: nikhil_staging/lilac/db_manager.py -- commit 56cce61 ("Push", nsthorat), 1.43 kB
"""Manages mapping the dataset name to the database instance."""
import os
import threading
from typing import Optional, Type

from .data.dataset import Dataset
# Class used to construct datasets. Explicitly bound to None (a bare annotation would not create
# the name) so that reading it before `set_default_dataset_cls` is called is well-defined.
_DEFAULT_DATASET_CLS: Optional[Type[Dataset]] = None
# Dataset instances keyed by '<namespace>/<dataset_name>'.
_CACHED_DATASETS: dict[str, Dataset] = {}
# Serializes reads and writes of `_CACHED_DATASETS`.
_db_lock = threading.Lock()
def get_dataset(namespace: str, dataset_name: str) -> Dataset:
  """Return the dataset instance for `namespace`/`dataset_name`, creating it if needed.

  Instances are cached under the key '<namespace>/<dataset_name>' and the cache is only touched
  while holding the module lock. When the PYTEST_CURRENT_TEST environment variable is present, a
  fresh instance is constructed (and re-cached) on every call.

  Args:
    namespace: The namespace of the dataset.
    dataset_name: The name of the dataset within the namespace.

  Returns:
    The cached or newly constructed dataset instance.

  Raises:
    ValueError: If no default dataset class has been set.
  """
  # The module-level name may be declared by annotation only (and therefore unbound) until
  # `set_default_dataset_cls` runs; look it up defensively so callers get the intended ValueError
  # rather than a NameError.
  dataset_cls = globals().get('_DEFAULT_DATASET_CLS')
  if not dataset_cls:
    raise ValueError('Default dataset class not set.')

  cache_key = f'{namespace}/{dataset_name}'
  # https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
  inside_test = 'PYTEST_CURRENT_TEST' in os.environ
  with _db_lock:
    if inside_test or cache_key not in _CACHED_DATASETS:
      _CACHED_DATASETS[cache_key] = dataset_cls(namespace=namespace, dataset_name=dataset_name)
    # Read the entry while still holding the lock so a concurrent removal cannot race the lookup.
    return _CACHED_DATASETS[cache_key]
def remove_dataset_from_cache(namespace: str, dataset_name: str) -> None:
  """Evict the cached dataset for `namespace`/`dataset_name`; a no-op if it is not cached."""
  with _db_lock:
    # pop() with a default drops the entry when present and silently does nothing otherwise.
    _CACHED_DATASETS.pop(f'{namespace}/{dataset_name}', None)
# TODO(nsthorat): Make this a registry once we have multiple dataset implementations. This breaks a
# circular dependency.
def set_default_dataset_cls(dataset_cls: Type[Dataset]) -> None:
  """Register `dataset_cls` as the class that `get_dataset` instantiates for new datasets."""
  # Rebind the module-level name directly in the module namespace.
  globals()['_DEFAULT_DATASET_CLS'] = dataset_cls