Elron committed on
Commit
0e631d9
1 Parent(s): 40f0408

Upload catalog.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. catalog.py +75 -25
catalog.py CHANGED
@@ -1,19 +1,26 @@
1
  import os
2
  import re
3
  from collections import Counter
 
4
  from pathlib import Path
5
  from typing import Optional
6
 
7
  import requests
8
 
9
- from .artifact import Artifact, Artifactories, Artifactory, reset_artifacts_cache
 
 
 
 
 
 
10
  from .logging_utils import get_logger
 
11
  from .text_utils import print_dict
12
  from .version import version
13
 
14
  logger = get_logger()
15
- COLLECTION_SEPARATOR = "."
16
- PATHS_SEP = ":"
17
 
18
 
19
  class Catalog(Artifactory):
@@ -21,42 +28,34 @@ class Catalog(Artifactory):
21
  location: str = None
22
 
23
 
24
- try:
25
- import unitxt
26
-
27
- if unitxt.__file__:
28
- lib_dir = os.path.dirname(unitxt.__file__)
29
- else:
30
- lib_dir = os.path.dirname(__file__)
31
- except ImportError:
32
- lib_dir = os.path.dirname(__file__)
33
-
34
- default_catalog_path = os.path.join(lib_dir, "catalog")
35
-
36
-
37
  class LocalCatalog(Catalog):
38
  name: str = "local"
39
- location: str = default_catalog_path
40
  is_local: bool = True
41
 
42
  def path(self, artifact_identifier: str):
43
  assert (
44
  artifact_identifier.strip()
45
  ), "artifact_identifier should not be an empty string."
46
- parts = artifact_identifier.split(COLLECTION_SEPARATOR)
47
  parts[-1] = parts[-1] + ".json"
48
  return os.path.join(self.location, *parts)
49
 
50
- def load(self, artifact_identifier: str):
51
  assert (
52
  artifact_identifier in self
53
  ), f"Artifact with name {artifact_identifier} does not exist"
54
  path = self.path(artifact_identifier)
55
- return Artifact.load(path, artifact_identifier)
 
 
56
 
57
  def __getitem__(self, name) -> Artifact:
58
  return self.load(name)
59
 
 
 
 
60
  def __contains__(self, artifact_identifier: str):
61
  if not os.path.exists(self.location):
62
  return False
@@ -101,11 +100,11 @@ class GithubCatalog(LocalCatalog):
101
  tag = version
102
  self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"
103
 
104
- def load(self, artifact_identifier: str):
105
  url = self.path(artifact_identifier)
106
  response = requests.get(url)
107
  data = response.json()
108
- new_artifact = Artifact.from_dict(data)
109
  new_artifact.artifact_identifier = artifact_identifier
110
  return new_artifact
111
 
@@ -117,7 +116,7 @@ class GithubCatalog(LocalCatalog):
117
 
118
  def verify_legal_catalog_name(name):
119
  assert re.match(
120
- r"^[\w" + COLLECTION_SEPARATOR + "]+$", name
121
  ), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'
122
 
123
 
@@ -132,7 +131,7 @@ def add_to_catalog(
132
  reset_artifacts_cache()
133
  if catalog is None:
134
  if catalog_path is None:
135
- catalog_path = default_catalog_path
136
  catalog = LocalCatalog(location=catalog_path)
137
  verify_legal_catalog_name(name)
138
  catalog.save_artifact(
@@ -141,6 +140,30 @@ def add_to_catalog(
141
  # verify name
142
 
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  def get_local_catalogs_paths():
145
  result = []
146
  for artifactory in Artifactories():
@@ -169,7 +192,34 @@ def local_catalog_summary(catalog_path):
169
 
170
  def summary():
171
  result = Counter()
 
172
  for local_catalog_path in get_local_catalogs_paths():
173
- result += Counter(local_catalog_summary(local_catalog_path))
 
 
174
  print_dict(result)
175
  return result
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
  import re
3
  from collections import Counter
4
+ from functools import lru_cache
5
  from pathlib import Path
6
  from typing import Optional
7
 
8
  import requests
9
 
10
+ from .artifact import (
11
+ Artifact,
12
+ Artifactories,
13
+ Artifactory,
14
+ get_artifactory_name_and_args,
15
+ reset_artifacts_cache,
16
+ )
17
  from .logging_utils import get_logger
18
+ from .settings_utils import get_constants
19
  from .text_utils import print_dict
20
  from .version import version
21
 
22
  logger = get_logger()
23
+ constants = get_constants()
 
24
 
25
 
26
  class Catalog(Artifactory):
 
28
  location: str = None
29
 
30
 
 
 
 
 
 
 
 
 
 
 
 
 
 
31
  class LocalCatalog(Catalog):
32
  name: str = "local"
33
+ location: str = constants.default_catalog_path
34
  is_local: bool = True
35
 
36
  def path(self, artifact_identifier: str):
37
  assert (
38
  artifact_identifier.strip()
39
  ), "artifact_identifier should not be an empty string."
40
+ parts = artifact_identifier.split(constants.catalog_hirarchy_sep)
41
  parts[-1] = parts[-1] + ".json"
42
  return os.path.join(self.location, *parts)
43
 
44
def load(self, artifact_identifier: str, overwrite_args=None):
    """Load the artifact stored in this catalog under ``artifact_identifier``.

    Args:
        artifact_identifier: Name of the artifact inside this catalog.
        overwrite_args: Passed through unchanged to ``Artifact.load``.

    Returns:
        Whatever ``Artifact.load`` returns for the artifact's file.

    Raises:
        AssertionError: If the identifier is not contained in this catalog.
    """
    assert (
        artifact_identifier in self
    ), f"Artifact with name {artifact_identifier} does not exist"
    return Artifact.load(
        self.path(artifact_identifier),
        artifact_identifier=artifact_identifier,
        overwrite_args=overwrite_args,
    )
52
 
53
  def __getitem__(self, name) -> Artifact:
54
  return self.load(name)
55
 
56
def get_with_overwrite(self, name, overwrite_args):
    """Load ``name`` via ``load``, forwarding ``overwrite_args`` to it.

    Args:
        name: Identifier of the artifact inside this catalog.
        overwrite_args: Value handed to ``load`` as ``overwrite_args``.

    Returns:
        The result of ``self.load``.
    """
    artifact = self.load(name, overwrite_args=overwrite_args)
    return artifact
58
+
59
  def __contains__(self, artifact_identifier: str):
60
  if not os.path.exists(self.location):
61
  return False
 
100
  tag = version
101
  self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"
102
 
103
def load(self, artifact_identifier: str, overwrite_args=None):
    """Download an artifact's JSON from the remote catalog and build it.

    Args:
        artifact_identifier: Name of the artifact; resolved to a URL by
            ``self.path``.
        overwrite_args: Passed through unchanged to ``Artifact.from_dict``.

    Returns:
        The artifact built by ``Artifact.from_dict``, with its
        ``artifact_identifier`` attribute set to the requested identifier.

    Raises:
        requests.exceptions.HTTPError: If the server answers with an error
            status (e.g. 404 for an unknown artifact).
        requests.exceptions.Timeout: If the server does not respond in time.
    """
    url = self.path(artifact_identifier)
    # Without a timeout requests may block forever on an unresponsive host.
    response = requests.get(url, timeout=30)
    # Surface HTTP failures directly instead of as a confusing JSON decoding
    # (or artifact construction) error on the error page's body.
    response.raise_for_status()
    data = response.json()
    new_artifact = Artifact.from_dict(data, overwrite_args=overwrite_args)
    new_artifact.artifact_identifier = artifact_identifier
    return new_artifact
110
 
 
116
 
117
  def verify_legal_catalog_name(name):
118
  assert re.match(
119
+ r"^[\w" + constants.catalog_hirarchy_sep + "]+$", name
120
  ), f'Artifict name ("{name}") should be alphanumeric. Use "." for nesting (e.g. myfolder.my_artifact)'
121
 
122
 
 
131
  reset_artifacts_cache()
132
  if catalog is None:
133
  if catalog_path is None:
134
+ catalog_path = constants.default_catalog_path
135
  catalog = LocalCatalog(location=catalog_path)
136
  verify_legal_catalog_name(name)
137
  catalog.save_artifact(
 
140
  # verify name
141
 
142
 
143
@lru_cache(maxsize=None)
def get_from_catalog(
    name: str,
    catalog: Catalog = None,
    catalog_path: Optional[str] = None,
):
    """Fetch an artifact by name, optionally restricted to a single catalog.

    Results are memoized by ``lru_cache`` (unbounded), so all arguments must
    be hashable and repeated calls with equal arguments return the same
    cached object.

    Args:
        name: Artifact name as accepted by ``get_artifactory_name_and_args``.
        catalog: Catalog to search; ignored when ``catalog_path`` is given.
        catalog_path: Location of a local catalog; when not ``None`` a
            ``LocalCatalog`` at this location replaces ``catalog``.

    Returns:
        The result of ``get_with_overwrite`` on the resolved catalog.
    """
    if catalog_path is not None:
        catalog = LocalCatalog(location=catalog_path)

    # ``None`` is passed through as-is when no specific catalog was requested.
    artifactories = None if catalog is None else [catalog]

    resolved_catalog, resolved_name, overwrite_args = get_artifactory_name_and_args(
        name, artifactories=artifactories
    )

    return resolved_catalog.get_with_overwrite(
        name=resolved_name,
        overwrite_args=overwrite_args,
    )
165
+
166
+
167
  def get_local_catalogs_paths():
168
  result = []
169
  for artifactory in Artifactories():
 
192
 
193
  def summary():
194
  result = Counter()
195
+ done = set()
196
  for local_catalog_path in get_local_catalogs_paths():
197
+ if local_catalog_path not in done:
198
+ result += Counter(local_catalog_summary(local_catalog_path))
199
+ done.add(local_catalog_path)
200
  print_dict(result)
201
  return result
202
+
203
+
204
def ls(to_file=None):
    """List the identifiers of all JSON artifacts in the local catalogs.

    Every distinct path returned by ``get_local_catalogs_paths`` is walked
    recursively. Each ``*.json`` file is turned into an identifier by taking
    its path relative to the catalog root, dropping the ``.json`` suffix and
    joining the path components with ``"."``.

    Args:
        to_file: Optional path of a file to write the listing to (one
            identifier per line). When falsy, the listing is logged instead.

    Returns:
        The list of artifact identifiers, in discovery order.
    """
    suffix = ".json"
    done = set()
    result = []
    for local_catalog_path in get_local_catalogs_paths():
        if local_catalog_path in done:
            continue
        # Record the path so a catalog reported twice is listed only once.
        done.add(local_catalog_path)
        for root, _, files in os.walk(local_catalog_path):
            for file in files:
                # Test the extension itself, not a substring, so files such
                # as "x.json.bak" or "x.jsonl" are not mistaken for artifacts.
                if not file.endswith(suffix):
                    continue
                file_path = os.path.relpath(
                    os.path.join(root, file), local_catalog_path
                )
                # Strip only the trailing suffix; ".json" occurring inside a
                # directory or file name must be preserved.
                file_id = ".".join(file_path[: -len(suffix)].split(os.path.sep))
                result.append(file_id)
    listing = "\n".join(result)
    if to_file:
        with open(to_file, "w", encoding="utf-8") as f:
            f.write(listing)
    else:
        logger.info(listing)
    return result