File size: 2,184 Bytes
b5ae7e6
36f3d38
1849dad
36f3d38
1849dad
31cee3d
b5ae7e6
1849dad
31cee3d
1849dad
b5ae7e6
1849dad
 
 
 
 
 
 
 
 
 
 
b113398
 
 
31cee3d
b113398
 
 
 
 
31cee3d
1849dad
b113398
 
1849dad
 
b113398
31cee3d
b5ae7e6
 
 
 
1849dad
b5ae7e6
1849dad
 
 
 
 
b5ae7e6
 
 
 
 
 
 
1849dad
 
b5ae7e6
1849dad
b5ae7e6
1849dad
b5ae7e6
1849dad
b5ae7e6
 
1849dad
36f3d38
b5ae7e6
1849dad
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import importlib
import inspect
import os

from .artifact import Artifact, Artifactories
from .catalog import PATHS_SEP, GithubCatalog, LocalCatalog
from .utils import Singleton

# Name of the environment variable holding extra local catalog directories,
# separated by PATHS_SEP; it is read by _register_all_catalogs below.
UNITXT_ARTIFACTORIES_ENV_VAR = "UNITXT_ARTIFACTORIES"

# Python files in this package that _register_all_artifacts does NOT import
# when scanning for Artifact subclasses to register.
non_registered_files = [
    "__init__.py",
    "artifact.py",
    "utils.py",
    "register.py",
    "metric.py",
    "dataset.py",
    "blocks.py",
]


def _register_catalog(catalog: LocalCatalog):
    """Hand ``catalog`` to ``Artifactories().register_atrifactory``.

    Args:
        catalog: The catalog object to register.
    """
    # NOTE(review): _register_all_catalogs also passes a GithubCatalog here, so
    # the LocalCatalog annotation may be too narrow unless GithubCatalog
    # subclasses it -- confirm. ("register_atrifactory" is the external
    # method's actual spelling and must not be "fixed" here.)
    registry = Artifactories()
    registry.register_atrifactory(catalog)


def register_local_catalog(catalog_path: str):
    """Register an existing directory as an additional ``LocalCatalog``.

    Args:
        catalog_path: Filesystem path of the directory to register.

    Raises:
        AssertionError: If ``catalog_path`` does not exist or is not a
            directory. The error is raised explicitly rather than via the
            ``assert`` statement, so the check still runs under ``python -O``.
            The exception type stays ``AssertionError`` for backward
            compatibility with callers that catch it.
    """
    if not os.path.exists(catalog_path):
        raise AssertionError(f"Catalog path {catalog_path} does not exist.")
    if not os.path.isdir(catalog_path):
        raise AssertionError(f"Catalog path {catalog_path} is not a directory.")
    _register_catalog(LocalCatalog(location=catalog_path))


def _register_all_catalogs():
    """Register the built-in catalogs plus any listed in the environment.

    Catalogs are registered in this order:
      1. a ``GithubCatalog()``,
      2. a ``LocalCatalog()`` constructed with no arguments,
      3. one ``LocalCatalog(location=path)`` per entry of the
         ``UNITXT_ARTIFACTORIES`` environment variable, split on ``PATHS_SEP``.

    Empty entries are skipped rather than registered as a catalog with an
    empty location. They arise from an empty variable or from a leading,
    trailing, or doubled separator.
    """
    _register_catalog(GithubCatalog())
    _register_catalog(LocalCatalog())
    env_paths = os.environ.get(UNITXT_ARTIFACTORIES_ENV_VAR, "")
    for path in env_paths.split(PATHS_SEP):
        if path:
            _register_catalog(LocalCatalog(location=path))


def _register_all_artifacts():
    """Import eligible sibling modules and register their Artifact subclasses.

    A file in this module's directory is eligible when it ends in ``.py``,
    is not listed in ``non_registered_files``, and is not this file itself.
    Each eligible module is imported relative to ``__package__``. Every class
    found among its members that is a strict subclass of ``Artifact`` is then
    passed to ``Artifact.register_class``.

    The members include classes a module merely imports, so the same class can
    be passed to ``register_class`` more than once. NOTE(review): presumably
    ``register_class`` tolerates repeats -- confirm.
    """
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)
    suffix = ".py"

    # os.listdir guarantees no order; sort so registration order is deterministic.
    for file in sorted(os.listdir(package_dir)):
        if not file.endswith(suffix) or file in non_registered_files or file == this_file:
            continue

        # Strip only the trailing extension. str.replace would also remove a
        # ".py" appearing earlier in the name.
        module_name = file[: -len(suffix)]
        module = importlib.import_module("." + module_name, __package__)

        # The isclass predicate restricts getmembers to classes only.
        for _, obj in inspect.getmembers(module, inspect.isclass):
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)


class ProjectArtifactRegisterer(metaclass=Singleton):
    """Run catalog and artifact registration, guarded by a ``_registered`` flag.

    ``__init__`` calls ``_register_all_catalogs`` and then
    ``_register_all_artifacts``, unless this instance already has a truthy
    ``_registered`` attribute. The flag is set to ``True`` only after both
    calls return. If either raises, it stays falsy and a later
    ``__init__`` call on the same instance will try again.
    """

    def __init__(self):
        try:
            already_done = self._registered
        except AttributeError:
            # First initialisation of this instance: create the flag.
            already_done = self._registered = False

        if already_done:
            return

        _register_all_catalogs()
        _register_all_artifacts()
        self._registered = True


def register_all_artifacts():
    """Trigger project-wide registration by constructing ``ProjectArtifactRegisterer``.

    Returns:
        None. The constructed object is discarded; only its side effects matter.
    """
    _ = ProjectArtifactRegisterer()