File size: 2,786 Bytes
5147f0d
 
1b98aa7
 
 
 
 
 
 
 
5147f0d
fa0fdab
 
5147f0d
fa0fdab
 
 
1b98aa7
5147f0d
1b98aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import os

from .artifact import Artifact, Artifactory, register_atrifactory
from .file_utils import get_all_files_in_dir


class Catalog(Artifactory):
    """Base class for named artifact catalogs.

    Subclasses provide concrete values for ``name`` and ``location``; what
    ``location`` points to depends on the subclass (in this module, a local
    directory for ``LocalCatalog`` and a URL for ``CommunityCatalog``).
    """

    name: str = None  # catalog identifier; unset (None) on this abstract base
    location: str = None  # where the catalog's artifacts live; unset (None) on this abstract base


try:
    import unitxt
except ImportError:
    # No importable ``unitxt`` package: use the "catalog" directory next to this module.
    catalog_path = os.path.dirname(__file__) + "/catalog"
else:
    # Prefer the "catalog" directory that ships inside the installed ``unitxt`` package.
    catalog_path = os.path.dirname(unitxt.__file__) + "/catalog"


class LocalCatalog(Catalog):
    """Catalog backed by a directory tree of ``.json`` artifact files.

    An artifact saved as ``<location>/<collection>/<identifier>.json`` is looked
    up by ``<identifier>`` alone (the file stem), independent of the collection
    sub-directory it lives in.
    """

    name: str = "local"
    location: str = catalog_path

    @property
    def path_dict(self):
        """Map every artifact identifier (file stem) to its ``.json`` file path.

        The tree under ``self.location`` is rescanned on every access, so newly
        saved files are visible immediately. If several files share a stem, the
        one yielded last by ``get_all_files_in_dir`` wins.
        """
        return {
            os.path.splitext(os.path.basename(path))[0]: path
            for path in get_all_files_in_dir(self.location, recursive=True, file_extension=".json")
        }

    def path(self, artifact_identifier: str):
        """Return the file path stored for ``artifact_identifier``, or ``None`` if unknown."""
        return self.path_dict.get(artifact_identifier)

    def load(self, artifact_identifier: str):
        """Load and return the artifact stored under ``artifact_identifier``.

        Raises:
            AssertionError: if the artifact does not exist in this catalog.
        """
        # Raised explicitly rather than via ``assert`` so the check is not stripped
        # under ``python -O``; the exception type is kept for existing callers.
        if artifact_identifier not in self:
            raise AssertionError("Artifact with name {} does not exist".format(artifact_identifier))
        return Artifact.load(self.path(artifact_identifier))

    def __getitem__(self, name) -> Artifact:
        """Subscript access; same as ``self.load(name)``."""
        return self.load(name)

    def __contains__(self, artifact_identifier: str):
        """Return True if ``artifact_identifier`` resolves to an existing file in this catalog."""
        if not os.path.exists(self.location):
            return False
        path = self.path(artifact_identifier)
        # os.path.isfile() is already False for missing paths; no separate exists() check needed.
        return path is not None and os.path.isfile(path)

    def save(self, artifact: Artifact, artifact_identifier: str, collection: str, overwrite: bool = False):
        """Write ``artifact`` to ``<location>/<collection>/<artifact_identifier>.json``.

        The collection directory is created if missing. Because lookups are by
        file stem across the whole tree, ``artifact_identifier`` must be unique in
        the entire catalog (not only within ``collection``) unless ``overwrite``.

        Raises:
            AssertionError: if ``artifact`` is not an ``Artifact``, or if an artifact
                with this identifier already exists and ``overwrite`` is False.
        """
        # Explicit raises instead of ``assert``: under ``python -O`` an assert-based
        # guard would vanish and existing artifacts would be silently overwritten.
        if not isinstance(artifact, Artifact):
            raise AssertionError("Artifact must be an instance of Artifact")

        if not overwrite and artifact_identifier in self:
            raise AssertionError(
                f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
            )

        collection_dir = os.path.join(self.location, collection)
        os.makedirs(collection_dir, exist_ok=True)
        artifact.save(os.path.join(collection_dir, artifact_identifier + ".json"))


register_atrifactory(LocalCatalog())

# When the ``unitxt`` package is importable, also register its bundled catalog
# directory under the name "library".
# NOTE(review): for a regular package ``unitxt.__path__[0]`` equals
# ``os.path.dirname(unitxt.__file__)``, so this is the same directory as the
# default "local" catalog's location -- confirm the duplicate registration is intended.
try:
    import unitxt
except ImportError:
    # unitxt is not installed; only the default "local" catalog is available.
    pass
else:
    # Keyword arguments avoid depending on the positional order of inherited fields.
    library_catalog = LocalCatalog(name="library", location=unitxt.__path__[0] + "/catalog")
    register_atrifactory(library_catalog)
# create a catalog for the community


class CommunityCatalog(Catalog):
    """Remote catalog of community-contributed artifacts (loading not implemented yet)."""

    # Annotated like LocalCatalog's fields so they are declared the same way as the
    # inherited Catalog fields. NOTE(review): with dataclass-style field handling, an
    # unannotated override may not replace the inherited default -- confirm in Artifact.
    name: str = "community"
    location: str = "https://raw.githubusercontent.com/unitxt/unitxt/main/catalog/community.json"

    def load(self, artifact_identifier: str):
        """Not implemented: fail loudly instead of silently returning ``None``.

        Raises:
            NotImplementedError: always.
        """
        raise NotImplementedError(
            f"{type(self).__name__}.load is not implemented yet (requested {artifact_identifier!r})"
        )


def add_to_catalog(
    artifact: Artifact, name: str, collection: str = None, catalog: Catalog = None, overwrite: bool = False
):
    """Save ``artifact`` under ``name`` in ``collection`` of ``catalog``.

    Args:
        artifact: the artifact to store.
        name: identifier to store the artifact under.
        collection: collection to store it in (for ``LocalCatalog``, a sub-directory
            of the catalog location). Required despite the ``None`` default.
        catalog: target catalog; a new ``LocalCatalog()`` is used when omitted.
        overwrite: forwarded to ``catalog.save``.

    Raises:
        TypeError: if ``collection`` is not provided.
    """
    # There is no meaningful default collection; fail early with a clear message
    # instead of passing a non-string down into ``catalog.save``.
    if collection is None:
        raise TypeError("add_to_catalog() requires a 'collection' argument")
    if catalog is None:
        catalog = LocalCatalog()
    catalog.save(artifact, name, collection, overwrite=overwrite)