File size: 1,363 Bytes
e4f9cbe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
"""Source registry for the dataset sources."""

from typing import Type, Union

from .source import Source

# Module-level registry mapping a source's `name` to its registered source class.
SOURCE_REGISTRY: dict[str, Type[Source]] = {}


def register_source(source_cls: Type[Source]) -> None:
  """Add a source class to the global registry, keyed by its `name`.

  Args:
    source_cls: The source class to register; `source_cls.name` is the registry key.

  Raises:
    ValueError: If a source with the same name is already in the registry.
  """
  name = source_cls.name
  if name in SOURCE_REGISTRY:
    raise ValueError(f'Source "{name}" has already been registered!')

  SOURCE_REGISTRY[name] = source_cls


def get_source_cls(source_name: str) -> Type[Source]:
  """Look up a registered source class by its name.

  Args:
    source_name: The key under which the source class was registered.

  Returns:
    The source class stored in the registry for `source_name`.

  Raises:
    ValueError: If no source is registered under `source_name`.
  """
  if source_name in SOURCE_REGISTRY:
    return SOURCE_REGISTRY[source_name]

  raise ValueError(f'Source "{source_name}" not found in the registry')


def registered_sources() -> dict[str, Type[Source]]:
  """Return the mapping of every registered source name to its class.

  The registry dict itself is returned, not a copy.
  """
  return SOURCE_REGISTRY


def resolve_source(source: Union[dict, Source]) -> Source:
  """Turn a source given as a dict into an instance of its registered class.

  Args:
    source: Either a `Source` instance, which is returned unchanged, or a dict
      whose 'source_name' entry selects the registered source class.

  Returns:
    The given `Source` instance, or a new instance of the registered class
    constructed from the dict.

  Raises:
    ValueError: If 'source_name' is missing or empty in the dict, or if no
      source is registered under that name.
  """
  if not isinstance(source, Source):
    name = source.get('source_name')
    if not name:
      raise ValueError('"source_name" needs to be defined in the json dict.')

    # The whole dict, including 'source_name', is forwarded as keyword arguments.
    return get_source_cls(name)(**source)

  # Already a concrete source instance; nothing to resolve.
  return source


def clear_source_registry() -> None:
  """Remove every entry from the global source registry.

  The registry dict is emptied in place.
  """
  SOURCE_REGISTRY.clear()