File size: 6,631 Bytes
5147f0d
12ca412
63f8e48
0e631d9
12ca412
334775e
5a833c3
12ca412
5147f0d
0e631d9
 
 
 
 
a3d663f
0e631d9
19fbe98
0e631d9
63f8e48
811dd6e
12ca412
19fbe98
0e631d9
1b98aa7
 
 
 
 
 
5147f0d
1b98aa7
 
0e631d9
63f8e48
1b98aa7
 
334775e
 
 
0e631d9
90e8224
 
1b98aa7
0e631d9
334775e
 
 
1b98aa7
0e631d9
 
 
1b98aa7
 
 
 
0e631d9
 
 
1b98aa7
 
 
 
 
 
 
 
334775e
 
 
 
 
 
 
 
 
 
1b98aa7
 
 
 
12ca412
 
1b98aa7
334775e
19fbe98
1b98aa7
 
a2b49fd
 
 
 
12ca412
 
 
 
 
63f8e48
5a833c3
12ca412
811dd6e
5a833c3
 
0e631d9
12ca412
 
 
0e631d9
9a07ce0
 
5a833c3
12ca412
 
 
 
1b98aa7
 
12ca412
5a833c3
0e631d9
334775e
1b98aa7
 
5a833c3
334775e
 
 
 
 
 
5a833c3
a3d663f
12ca412
 
0e631d9
12ca412
 
334775e
 
 
12ca412
63f8e48
 
0e631d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63f8e48
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0e631d9
63f8e48
0e631d9
 
 
63f8e48
 
0e631d9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
import os
import re
from collections import Counter
from functools import lru_cache
from pathlib import Path
from typing import Optional

import requests

from .artifact import (
    Artifact,
    Artifactories,
    Artifactory,
    get_artifactory_name_and_args,
    reset_artifacts_json_cache,
)
from .logging_utils import get_logger
from .settings_utils import get_constants
from .text_utils import print_dict
from .version import version

# Module-level logger, obtained from the project's logging utilities.
logger = get_logger()
# Project constants object; this module reads `default_catalog_path` and
# `catalog_hirarchy_sep` from it.
constants = get_constants()


class Catalog(Artifactory):
    """Base class for catalogs: an Artifactory identified by a name and a location.

    Both attributes default to None here; the subclasses in this module
    override them with concrete values.
    """

    # Short identifier of the catalog (LocalCatalog uses "local").
    name: str = None
    # Where the catalog lives; subclasses in this module use either a
    # filesystem directory or a base URL.
    location: str = None


class LocalCatalog(Catalog):
    """Catalog whose artifacts are JSON files under a local directory.

    An artifact identifier is mapped to a file by splitting it on
    ``constants.catalog_hirarchy_sep``: every part but the last names a nested
    directory below ``location`` and the last part, with ``.json`` appended,
    names the file.
    """

    name: str = "local"
    location: str = constants.default_catalog_path
    is_local: bool = True

    def path(self, artifact_identifier: str):
        """Return the JSON file path that corresponds to ``artifact_identifier``.

        Raises:
            AssertionError: if the identifier is empty or only whitespace.
        """
        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        *folders, leaf = artifact_identifier.split(constants.catalog_hirarchy_sep)
        return os.path.join(self.location, *folders, leaf + ".json")

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Load the artifact stored under ``artifact_identifier``.

        ``overwrite_args`` is forwarded unchanged to ``Artifact.load``.

        Raises:
            AssertionError: if no such artifact exists in this catalog.
        """
        assert (
            artifact_identifier in self
        ), f"Artifact with name {artifact_identifier} does not exist"
        return Artifact.load(
            self.path(artifact_identifier),
            artifact_identifier=artifact_identifier,
            overwrite_args=overwrite_args,
        )

    def __getitem__(self, name) -> Artifact:
        """Support ``catalog[name]`` as a shorthand for ``load(name)``."""
        return self.load(name)

    def get_with_overwrite(self, name, overwrite_args):
        """Load ``name`` and pass ``overwrite_args`` on to the loader."""
        return self.load(name, overwrite_args=overwrite_args)

    def __contains__(self, artifact_identifier: str):
        """Tell whether a JSON file for ``artifact_identifier`` exists on disk."""
        # The catalog directory is checked first, so a missing catalog answers
        # False before the identifier is even validated by path().
        if os.path.exists(self.location):
            candidate = self.path(artifact_identifier)
            return candidate is not None and os.path.isfile(candidate)
        return False

    def save_artifact(
        self,
        artifact: Artifact,
        artifact_identifier: str,
        overwrite: bool = False,
        verbose: bool = True,
    ):
        """Write ``artifact`` to the file that ``artifact_identifier`` maps to.

        Missing parent directories are created. When ``verbose`` is true the
        destination is logged at info level.

        Raises:
            AssertionError: if ``artifact`` is not an ``Artifact``, or if the
                identifier is already present and ``overwrite`` is false.
        """
        assert isinstance(
            artifact, Artifact
        ), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
        assert (
            overwrite or artifact_identifier not in self
        ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
        target = self.path(artifact_identifier)
        parent_dir = Path(target).parent.absolute()
        os.makedirs(parent_dir, exist_ok=True)
        artifact.save(target)
        if verbose:
            logger.info(f"Artifact {artifact_identifier} saved to {target}")


class EnvironmentLocalCatalog(LocalCatalog):
    """LocalCatalog subclass that adds no fields or behavior of its own.

    NOTE(review): presumably a marker type for local catalogs that were
    registered through environment configuration -- confirm against the code
    that instantiates it.
    """

    pass


class GithubCatalog(LocalCatalog):
    """Catalog backed by the catalog directory of a GitHub repository.

    Artifacts are fetched over HTTPS from ``raw.githubusercontent.com``; the
    repository tag used is the package ``version``. ``location`` is a base URL
    (set by ``prepare``) rather than a filesystem directory.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    is_local: bool = False
    # Seconds to wait for GitHub before giving up. Without a timeout,
    # `requests` may block indefinitely on an unresponsive connection.
    request_timeout = 30

    def prepare(self):
        """Compute the base URL of the catalog for the current package version."""
        tag = version
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Return the raw-content URL of the artifact's JSON file.

        Overrides the filesystem implementation so that URL segments are always
        joined with "/", independent of the host OS path separator (the
        inherited ``os.path.join`` would insert backslashes on Windows).

        Raises:
            AssertionError: if the identifier is empty or only whitespace.
        """
        import posixpath  # local import: URL joining must not depend on os.sep

        assert (
            artifact_identifier.strip()
        ), "artifact_identifier should not be an empty string."
        parts = artifact_identifier.split(constants.catalog_hirarchy_sep)
        parts[-1] = parts[-1] + ".json"
        return posixpath.join(self.location, *parts)

    def load(self, artifact_identifier: str, overwrite_args=None):
        """Download the artifact's JSON and build an ``Artifact`` from it.

        ``overwrite_args`` is forwarded to ``Artifact.from_dict``. The returned
        artifact has its ``artifact_identifier`` attribute set to the
        identifier that was requested.
        """
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=self.request_timeout)
        data = response.json()
        new_artifact = Artifact.from_dict(data, overwrite_args=overwrite_args)
        new_artifact.artifact_identifier = artifact_identifier
        return new_artifact

    def __contains__(self, artifact_identifier: str):
        """Tell whether the artifact's URL answers a HEAD request with 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=self.request_timeout)
        return response.status_code == 200


def verify_legal_catalog_name(name):
    """Assert that ``name`` is a legal catalog identifier.

    A legal name consists only of word characters (``\\w``) and the characters
    of ``constants.catalog_hirarchy_sep``.

    Raises:
        AssertionError: if ``name`` contains any other character or is empty.
    """
    # The separator is escaped so it is always taken literally inside the
    # character class, and fullmatch() is used because "$" with match() would
    # also accept a name that ends in a newline.
    legal_name = r"[\w" + re.escape(constants.catalog_hirarchy_sep) + r"]+"
    assert re.fullmatch(
        legal_name, name
    ), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'


def add_to_catalog(
    artifact: Artifact,
    name: str,
    catalog: Catalog = None,
    overwrite: bool = False,
    catalog_path: Optional[str] = None,
    verbose=True,
):
    """Save ``artifact`` under ``name`` in a catalog.

    Args:
        artifact: the artifact to store.
        name: catalog identifier; checked with ``verify_legal_catalog_name``.
        catalog: destination catalog. When None, a ``LocalCatalog`` is created
            at ``catalog_path``, or at ``constants.default_catalog_path`` if
            ``catalog_path`` is None as well.
        overwrite: forwarded to ``save_artifact``.
        catalog_path: only used when ``catalog`` is None.
        verbose: forwarded to ``save_artifact``.
    """
    # Called before anything else, exactly as the first step of every save.
    reset_artifacts_json_cache()
    if catalog is None:
        location = (
            constants.default_catalog_path if catalog_path is None else catalog_path
        )
        catalog = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    catalog.save_artifact(artifact, name, overwrite=overwrite, verbose=verbose)


@lru_cache(maxsize=None)
def get_from_catalog(
    name: str,
    catalog: Catalog = None,
    catalog_path: Optional[str] = None,
):
    """Fetch the artifact called ``name`` from a catalog.

    Results are memoized by ``lru_cache`` on the argument values, so the
    arguments must be hashable and repeated calls with the same arguments
    return the same cached object.

    Args:
        name: identifier passed to ``get_artifactory_name_and_args``, which
            yields the catalog to use, the name to load, and overwrite args.
        catalog: when given, the lookup is restricted to this catalog.
        catalog_path: when given, takes precedence over ``catalog``: a
            ``LocalCatalog`` at this location is used instead.
    """
    if catalog_path is not None:
        catalog = LocalCatalog(location=catalog_path)

    # None lets get_artifactory_name_and_args pick its own artifactories.
    search_scope = None if catalog is None else [catalog]

    resolved_catalog, resolved_name, overwrite_args = get_artifactory_name_and_args(
        name, artifactories=search_scope
    )
    return resolved_catalog.get_with_overwrite(
        name=resolved_name,
        overwrite_args=overwrite_args,
    )


def get_local_catalogs_paths():
    """Return the ``location`` of every registered catalog that is truly local.

    An artifactory qualifies when it is a ``LocalCatalog`` instance whose
    ``is_local`` flag is truthy; order follows iteration over ``Artifactories()``.
    """
    return [
        candidate.location
        for candidate in Artifactories()
        if isinstance(candidate, LocalCatalog) and candidate.is_local
    ]


def count_files_recursively(folder):
    """Return the total number of files found under ``folder`` at any depth."""
    return sum(len(filenames) for _dirpath, _dirnames, filenames in os.walk(folder))


def local_catalog_summary(catalog_path):
    """Map each immediate sub-directory of ``catalog_path`` to its file count.

    Files that sit directly in ``catalog_path`` are not counted; the count for
    a sub-directory covers everything below it.
    """
    counts_by_dir = {}
    for entry in os.listdir(catalog_path):
        entry_path = os.path.join(catalog_path, entry)
        if not os.path.isdir(entry_path):
            continue
        counts_by_dir[entry] = count_files_recursively(entry_path)
    return counts_by_dir


def summary():
    """Print and return per-directory file counts summed over all local catalogs.

    Each distinct catalog path is summarized once, even if it is registered
    more than once.
    """
    totals = Counter()
    # dict.fromkeys keeps the first occurrence of each path, in order.
    for catalog_path in dict.fromkeys(get_local_catalogs_paths()):
        totals += Counter(local_catalog_summary(catalog_path))
    print_dict(totals)
    return totals


def _artifact_ids_under(catalog_path):
    """Return the identifiers of all ``.json`` files below ``catalog_path``."""
    suffix = ".json"
    artifact_ids = []
    for root, _, files in os.walk(catalog_path):
        for file in files:
            # Only genuine JSON files count; a substring test would also pick
            # up names such as "x.json.bak".
            if not file.endswith(suffix):
                continue
            relative = os.path.relpath(os.path.join(root, file), catalog_path)
            # Drop the extension at the end only, so a directory whose name
            # happens to contain ".json" is left intact.
            parts = relative[: -len(suffix)].split(os.path.sep)
            # Joined with the separator LocalCatalog.path() splits on, so the
            # result can be fed back to the catalog.
            artifact_ids.append(constants.catalog_hirarchy_sep.join(parts))
    return artifact_ids


def ls(to_file=None):
    """List the identifiers of all artifacts in the local catalogs.

    Args:
        to_file: if truthy, path of a file to write the identifiers to, one per
            line; otherwise the listing is logged at info level.

    Returns:
        The list of identifiers. A catalog path that is registered more than
        once is walked only once.
    """
    seen = set()
    result = []
    for local_catalog_path in get_local_catalogs_paths():
        if local_catalog_path in seen:
            continue
        seen.add(local_catalog_path)
        result.extend(_artifact_ids_under(local_catalog_path))
    if to_file:
        with open(to_file, "w") as f:
            f.write("\n".join(result))
    else:
        logger.info("\n".join(result))
    return result