File size: 1,430 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
"""Manages mapping the dataset name to the database instance."""
import os
import threading
from typing import Optional, Type

from .data.dataset import Dataset

# The class used to construct dataset instances. It is explicitly bound to `None` until
# `set_default_dataset_cls` is called: a bare annotation without a value would leave the name
# undefined, so the "not set" check in `get_dataset` would fail with a `NameError`.
_DEFAULT_DATASET_CLS: Optional[Type[Dataset]] = None

# Dataset instances that were already constructed, keyed by '{namespace}/{dataset_name}'.
_CACHED_DATASETS: dict[str, Dataset] = {}

# Held while reading or mutating `_CACHED_DATASETS`.
_db_lock = threading.Lock()


def get_dataset(namespace: str, dataset_name: str) -> Dataset:
  """Return the dataset instance for `namespace`/`dataset_name`, constructing it if needed.

  Instances are cached under the key '{namespace}/{dataset_name}'. When running under pytest
  (the `PYTEST_CURRENT_TEST` environment variable is present), a new instance is always
  constructed and replaces any cached one.

  Args:
    namespace: The namespace of the dataset.
    dataset_name: The name of the dataset inside the namespace.

  Returns:
    The cached (or newly constructed) dataset instance.

  Raises:
    ValueError: If no default dataset class has been set via `set_default_dataset_cls`.
  """
  # Look the class up through `globals()` so that a never-assigned module variable results in the
  # ValueError below rather than a NameError. Reading it once also guarantees that the class that
  # was checked is the class that gets instantiated.
  dataset_cls = globals().get('_DEFAULT_DATASET_CLS')
  if dataset_cls is None:
    raise ValueError('Default dataset class not set.')

  cache_key = f'{namespace}/{dataset_name}'
  # https://docs.pytest.org/en/latest/example/simple.html#pytest-current-test-environment-variable
  inside_test = 'PYTEST_CURRENT_TEST' in os.environ
  with _db_lock:
    if inside_test or cache_key not in _CACHED_DATASETS:
      _CACHED_DATASETS[cache_key] = dataset_cls(namespace=namespace, dataset_name=dataset_name)
    return _CACHED_DATASETS[cache_key]


def remove_dataset_from_cache(namespace: str, dataset_name: str) -> None:
  """Evict the cached instance for `namespace`/`dataset_name`, if there is one.

  Does nothing when the dataset is not in the cache.

  Args:
    namespace: The namespace of the dataset.
    dataset_name: The name of the dataset inside the namespace.
  """
  key = f'{namespace}/{dataset_name}'
  with _db_lock:
    # `pop` with a default is a no-op for keys that are not cached.
    _CACHED_DATASETS.pop(key, None)


# TODO(nsthorat): Make this a registry once we have multiple dataset implementations. This breaks a
# circular dependency.
def set_default_dataset_cls(dataset_cls: Type[Dataset]) -> None:
  """Register `dataset_cls` as the class that `get_dataset` instantiates.

  Only the module-level default is replaced; entries already in the dataset cache are left
  untouched.

  Args:
    dataset_cls: The dataset class to use for datasets constructed from now on.
  """
  # Rebinds the module-level variable, same effect as a `global` declaration plus assignment.
  globals()['_DEFAULT_DATASET_CLS'] = dataset_cls