File size: 3,913 Bytes
5147f0d
12ca412
 
334775e
5a833c3
12ca412
5147f0d
99056a8
19fbe98
811dd6e
12ca412
19fbe98
5a833c3
 
1b98aa7
 
 
 
 
 
5147f0d
fa0fdab
 
5147f0d
5a833c3
 
 
 
fa0fdab
5a833c3
 
 
1b98aa7
5147f0d
1b98aa7
 
12ca412
1b98aa7
 
334775e
 
 
90e8224
 
 
1b98aa7
 
334775e
 
 
1b98aa7
334775e
1b98aa7
 
 
 
 
 
 
 
 
 
 
 
334775e
 
 
 
 
 
 
 
 
 
1b98aa7
 
 
 
12ca412
 
1b98aa7
334775e
19fbe98
1b98aa7
 
a2b49fd
 
 
 
12ca412
 
 
 
 
5a833c3
12ca412
811dd6e
5a833c3
 
12ca412
 
 
 
 
5a833c3
12ca412
 
 
 
1b98aa7
 
12ca412
5a833c3
 
334775e
1b98aa7
 
5a833c3
334775e
 
 
 
 
 
5a833c3
99056a8
12ca412
 
 
 
 
334775e
 
 
12ca412
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import os
import re
from pathlib import Path
from typing import Optional

import requests

from .artifact import Artifact, Artifactory, reset_artifacts_cache
from .logging_utils import get_logger
from .version import version

# Module-level logger obtained from the project's logging helper.
logger = get_logger()
# Separator between nesting levels of an artifact identifier; LocalCatalog.path
# splits identifiers on it (e.g. "myfolder.my_artifact").
COLLECTION_SEPARATOR = "."
# NOTE(review): not referenced in the visible part of this file; presumably
# separates several catalog paths packed into one string - confirm with callers.
PATHS_SEP = ":"


class Catalog(Artifactory):
    """Base type for catalogs: an Artifactory that carries a name and a location."""

    # Short label of the catalog; appears in messages such as the
    # "already exists in catalog ..." assertion of LocalCatalog.save_artifact.
    name: str = None
    # Root under which artifacts are addressed; subclasses in this file use a
    # directory path (LocalCatalog) or a base URL (GithubCatalog).
    location: str = None


# Resolve the directory that contains the bundled catalog. The installed
# ``unitxt`` package is preferred; this module's own directory is the fallback
# when the package cannot be imported or reports no ``__file__``.
try:
    import unitxt
except ImportError:
    lib_dir = os.path.dirname(__file__)
else:
    lib_dir = os.path.dirname(unitxt.__file__ or __file__)

# Default location of the catalog shipped next to the library code.
default_catalog_path = os.path.join(lib_dir, "catalog")


class LocalCatalog(Catalog):
    """Catalog whose artifacts are JSON files under a local directory.

    An identifier such as ``a.b.c`` is addressed as ``<location>/a/b/c.json``.
    """

    name: str = "local"
    location: str = default_catalog_path

    def path(self, artifact_identifier: str):
        """Return the JSON file path that corresponds to ``artifact_identifier``.

        All separator-delimited segments but the last become directories under
        ``self.location``; the last segment receives a ``.json`` suffix. The
        file itself does not have to exist.
        """
        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        *folders, leaf = artifact_identifier.split(COLLECTION_SEPARATOR)
        return os.path.join(self.location, *folders, leaf + ".json")

    def load(self, artifact_identifier: str):
        """Load and return the artifact stored under ``artifact_identifier``.

        An assertion fails when the catalog holds no such artifact.
        """
        assert (
            artifact_identifier in self
        ), f"Artifact with name {artifact_identifier} does not exist"
        return Artifact.load(self.path(artifact_identifier))

    def __getitem__(self, name) -> Artifact:
        """Support ``catalog[name]`` as a shorthand for ``catalog.load(name)``."""
        return self.load(name)

    def __contains__(self, artifact_identifier: str):
        """Tell whether ``artifact_identifier`` maps to an existing file here."""
        # The catalog directory is checked first, before the identifier is
        # turned into a path.
        if os.path.exists(self.location):
            candidate = self.path(artifact_identifier)
            if candidate is not None:
                return os.path.isfile(candidate)
        return False

    def save_artifact(
        self,
        artifact: Artifact,
        artifact_identifier: str,
        overwrite: bool = False,
        verbose: bool = True,
    ):
        """Write ``artifact`` into the catalog under ``artifact_identifier``.

        Args:
            artifact: the Artifact instance to store.
            artifact_identifier: separator-delimited name to store it under.
            overwrite: when False, an already existing entry fails an assertion.
            verbose: when True, log the destination path after saving.
        """
        assert isinstance(
            artifact, Artifact
        ), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        # The existence check only runs when overwriting is not allowed.
        assert (
            overwrite or artifact_identifier not in self
        ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        path = self.path(artifact_identifier)
        # Create the nested collection directories before writing the file.
        parent_dir = Path(path).parent.absolute()
        os.makedirs(parent_dir, exist_ok=True)
        artifact.save(path)
        if verbose:
            logger.info(f"Artifact {artifact_identifier} saved to {path}")


class EnvironmentLocalCatalog(LocalCatalog):
    """A LocalCatalog under a distinct class name; adds no fields or behavior.

    NOTE(review): presumably marks catalogs registered through environment
    configuration - confirm against the code that instantiates it.
    """


class GithubCatalog(LocalCatalog):
    """Catalog backed by the raw files of a GitHub repository.

    Artifacts are addressed as URLs under ``raw.githubusercontent.com`` for the
    configured ``user``/``repo``/``repo_dir`` at the tag equal to ``version``.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    # Seconds to wait for GitHub before giving up. Without an explicit timeout
    # ``requests`` can block indefinitely on an unresponsive server.
    request_timeout = 30

    def prepare(self):
        """Point ``location`` at the repository's raw-content base URL.

        NOTE(review): presumably invoked by the Artifact machinery after
        construction - confirm in the base class.
        """
        tag = version
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Return the URL of the JSON file for ``artifact_identifier``."""
        # The inherited implementation joins segments with os.path.join, which
        # emits backslashes on Windows; URLs must always use forward slashes.
        # On POSIX this replacement is a no-op.
        return super().path(artifact_identifier).replace(os.sep, "/")

    def load(self, artifact_identifier: str):
        """Download the artifact's JSON and build the Artifact from it.

        Raises:
            requests.HTTPError: if the server answers with an error status
                (e.g. 404 for an unknown artifact).
            requests.Timeout: if no answer arrives within ``request_timeout``.
        """
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=self.request_timeout)
        # Fail with the HTTP status rather than a confusing JSON decoding
        # error raised while parsing an error page.
        response.raise_for_status()
        data = response.json()
        return Artifact.from_dict(data)

    def __contains__(self, artifact_identifier: str):
        """Tell whether a HEAD request for the artifact's URL returns 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=self.request_timeout)
        return response.status_code == 200


def verify_legal_catalog_name(name):
    """Assert that ``name`` is a legal catalog identifier.

    A legal name consists solely of word characters (letters, digits,
    underscore) and the collection separator used for nesting.

    Raises:
        AssertionError: if ``name`` contains any other character.
    """
    # fullmatch is used instead of match with a trailing "$": "$" also matches
    # just before a final newline, which let names ending in a newline pass.
    # re.escape keeps the separator literal inside the character class.
    legal_name = r"[\w" + re.escape(COLLECTION_SEPARATOR) + "]+"
    assert re.fullmatch(
        legal_name, name
    ), f'Artifact name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'


def add_to_catalog(
    artifact: Artifact,
    name: str,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Store ``artifact`` under ``name`` in a catalog.

    Args:
        artifact: the Artifact to store.
        name: catalog identifier; checked with ``verify_legal_catalog_name``.
        catalog: target catalog. When None, a LocalCatalog is created.
        overwrite: forwarded to the catalog's ``save_artifact``.
        catalog_path: directory for the LocalCatalog created when ``catalog``
            is None; defaults to ``default_catalog_path``.
        verbose: forwarded to the catalog's ``save_artifact``.
    """
    # The artifacts cache is reset up front, before any validation happens.
    reset_artifacts_cache()
    if catalog is None:
        location = default_catalog_path if catalog_path is None else catalog_path
        catalog = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite, verbose=verbose)