File size: 3,508 Bytes
5147f0d
12ca412
 
5a833c3
12ca412
5147f0d
5a833c3
811dd6e
12ca412
5a833c3
 
1b98aa7
 
 
 
 
 
5147f0d
fa0fdab
 
5147f0d
5a833c3
 
 
 
fa0fdab
5a833c3
 
 
1b98aa7
5147f0d
1b98aa7
 
12ca412
1b98aa7
 
5a833c3
90e8224
 
 
1b98aa7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12ca412
 
1b98aa7
 
 
 
12ca412
 
1b98aa7
5a833c3
1b98aa7
 
12ca412
 
 
 
 
5a833c3
12ca412
811dd6e
5a833c3
 
12ca412
 
 
 
 
5a833c3
12ca412
 
 
 
1b98aa7
 
12ca412
5a833c3
 
 
1b98aa7
 
5a833c3
 
 
12ca412
 
 
 
 
5a833c3
12ca412
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import os
import posixpath
import re
from pathlib import Path

import requests

from .artifact import Artifact, Artifactory
from .version import version

# Separates the nested collection names inside an artifact identifier
# (e.g. "a.b.c"); LocalCatalog.path turns each segment into a path component.
COLLECTION_SEPARATOR = "."
# NOTE(review): not referenced in the visible part of this file - presumably
# separates several catalog paths given in one string; confirm against callers.
PATHS_SEP = ":"


class Catalog(Artifactory):
    """Base catalog type declaring the two fields that the catalogs in this file share."""

    # Label of the catalog; LocalCatalog reports it in its "already exists" message.
    name: str = None
    # Root under which artifacts are looked up: a directory path for
    # LocalCatalog, a URL prefix for GithubCatalog.
    location: str = None


# Resolve the directory that contains the bundled "catalog" folder: the
# directory of the importable ``unitxt`` package when it reports a file
# location, otherwise the directory of this module.
try:
    import unitxt

    lib_dir = os.path.dirname(unitxt.__file__ or __file__)
except ImportError:
    lib_dir = os.path.dirname(__file__)

default_catalog_path = os.path.join(lib_dir, "catalog")


class LocalCatalog(Catalog):
    """Catalog that keeps every artifact as a JSON file under a local directory.

    An artifact identifier is split on ``COLLECTION_SEPARATOR``; every segment
    except the last becomes a sub-directory of ``location`` and the last one
    becomes the file name with a ``.json`` suffix.
    """

    name: str = "local"
    location: str = default_catalog_path

    def path(self, artifact_identifier: str):
        """Return the file path that ``artifact_identifier`` maps to.

        Raises:
            AssertionError: If the identifier is empty or whitespace only.
        """
        assert artifact_identifier.strip(), "artifact_identifier should not be an empty string."
        parts = artifact_identifier.split(COLLECTION_SEPARATOR)
        parts[-1] = parts[-1] + ".json"
        return os.path.join(self.location, *parts)

    def load(self, artifact_identifier: str):
        """Load and return the artifact stored under ``artifact_identifier``.

        Raises:
            AssertionError: If no such artifact exists in this catalog.
        """
        assert artifact_identifier in self, f"Artifact with name {artifact_identifier} does not exist"
        return Artifact.load(self.path(artifact_identifier))

    def __getitem__(self, name) -> Artifact:
        """Support ``catalog[name]`` as a shorthand for ``load(name)``."""
        return self.load(name)

    def __contains__(self, artifact_identifier: str):
        """Return True when a file for ``artifact_identifier`` exists in the catalog."""
        if not os.path.exists(self.location):
            return False
        # path() always returns a string, and isfile() is already False for a
        # missing path, so neither a None check nor a separate exists() is needed.
        return os.path.isfile(self.path(artifact_identifier))

    def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
        """Write ``artifact`` to the file that ``artifact_identifier`` maps to.

        Args:
            artifact: The artifact to store.
            artifact_identifier: Name under which the artifact is stored.
            overwrite: When False, refuse to replace an existing artifact.

        Raises:
            AssertionError: If ``artifact`` is not an ``Artifact``, or if the
                identifier is already taken and ``overwrite`` is False.
        """
        assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        if not overwrite:
            assert (
                artifact_identifier not in self
            ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        path = self.path(artifact_identifier)
        # Create the collection directories on first use.
        os.makedirs(Path(path).parent.absolute(), exist_ok=True)
        artifact.save(path)
        print(f"Artifact {artifact_identifier} saved to {path}")


class GithubCatalog(LocalCatalog):
    """Catalog whose artifacts are fetched over HTTP from a GitHub repository.

    ``prepare`` points ``location`` at the raw-content URL of ``repo_dir`` in
    ``user/repo`` at the tag equal to the package ``version``.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    # Seconds to wait for the server; without a timeout requests may block forever.
    request_timeout = 30

    def prepare(self):
        """Set ``location`` to the raw GitHub URL of the catalog directory."""
        tag = version
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Return the URL of the JSON file for ``artifact_identifier``.

        Raises:
            AssertionError: If the identifier is empty or whitespace only.
        """
        assert artifact_identifier.strip(), "artifact_identifier should not be an empty string."
        parts = artifact_identifier.split(COLLECTION_SEPARATOR)
        parts[-1] = parts[-1] + ".json"
        # URLs always use "/"; os.path.join would insert "\\" on Windows.
        return posixpath.join(self.location, *parts)

    def load(self, artifact_identifier: str):
        """Download the artifact's JSON and build an ``Artifact`` from it."""
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=self.request_timeout)
        data = response.json()
        return Artifact.from_dict(data)

    def __contains__(self, artifact_identifier: str):
        """Return True when a HEAD request for the artifact URL answers 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=self.request_timeout)
        return response.status_code == 200

def verify_legal_catalog_name(name):
    """Check that ``name`` is a legal artifact name for a catalog.

    A legal name is non-empty and made only of word characters (letters,
    digits, underscore) and ``COLLECTION_SEPARATOR``, which marks directory
    levels.

    Args:
        name: The artifact name to validate.

    Raises:
        AssertionError: If ``name`` contains any other character.
    """
    # fullmatch, unlike match with "$", also rejects a name ending in a newline;
    # re.escape keeps the separator literal inside the character class.
    assert re.fullmatch(
        r"[\w" + re.escape(COLLECTION_SEPARATOR) + r"]+", name
    ), f'Catalog name should be alphanumeric, "{COLLECTION_SEPARATOR}" should specify dirs (instead of "/").'


def add_to_catalog(
    artifact: Artifact, name: str, catalog: Catalog = None, overwrite: bool = False, catalog_path: str = None
):
    """Store ``artifact`` under ``name`` in a catalog.

    Args:
        artifact: The artifact to store.
        name: Artifact name; checked with ``verify_legal_catalog_name``.
        catalog: Catalog to write to. When omitted, a ``LocalCatalog`` is
            created for ``catalog_path``.
        overwrite: Passed on to the catalog's ``save_artifact``.
        catalog_path: Directory for the implicitly created ``LocalCatalog``;
            falls back to ``default_catalog_path``. Not used when ``catalog``
            is given.
    """
    if catalog is None:
        target_dir = default_catalog_path if catalog_path is None else catalog_path
        catalog = LocalCatalog(location=target_dir)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite)