Elron committed on
Commit
12ca412
1 Parent(s): 6f3c593

Upload catalog.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. catalog.py +51 -38
catalog.py CHANGED
@@ -1,7 +1,13 @@
1
  import os
 
 
 
 
 
2
 
3
- from .artifact import Artifact, Artifactory, register_atrifactory
4
- from .file_utils import get_all_files_in_dir
 
5
 
6
 
7
  class Catalog(Artifactory):
@@ -12,25 +18,17 @@ class Catalog(Artifactory):
12
  try:
13
  import unitxt
14
 
15
- catalog_path = os.path.dirname(unitxt.__file__) + "/catalog"
16
  except ImportError:
17
- catalog_path = os.path.dirname(__file__) + "/catalog"
18
 
19
 
20
  class LocalCatalog(Catalog):
21
  name: str = "local"
22
- location: str = catalog_path
23
-
24
- @property
25
- def path_dict(self):
26
- result = {}
27
- for path in get_all_files_in_dir(self.location, recursive=True, file_extension=".json"):
28
- name = os.path.splitext(os.path.basename(path))[0]
29
- result[name] = path
30
- return result
31
 
32
  def path(self, artifact_identifier: str):
33
- return self.path_dict.get(artifact_identifier, None)
34
 
35
  def load(self, artifact_identifier: str):
36
  assert artifact_identifier in self, "Artifact with name {} does not exist".format(artifact_identifier)
@@ -49,41 +47,56 @@ class LocalCatalog(Catalog):
49
  return False
50
  return os.path.exists(path) and os.path.isfile(path)
51
 
52
- def save(self, artifact: Artifact, artifact_identifier: str, collection: str, overwrite: bool = False):
53
- assert isinstance(artifact, Artifact), "Artifact must be an instance of Artifact"
54
 
 
 
 
55
  if not overwrite:
56
  assert (
57
  artifact_identifier not in self
58
  ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
59
-
60
- collection_dir = os.path.join(self.location, collection)
61
- os.makedirs(collection_dir, exist_ok=True)
62
- path = os.path.join(collection_dir, artifact_identifier + ".json")
63
  artifact.save(path)
64
 
65
 
66
- register_atrifactory(LocalCatalog())
67
-
68
- try:
69
- import unitxt
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- library_catalog = LocalCatalog("library", unitxt.__path__[0] + "/catalog")
72
- register_atrifactory(library_catalog)
73
- except:
74
- pass
75
- # create a catalog for the community
76
 
 
 
 
77
 
78
- class CommunityCatalog(Catalog):
79
- name = "community"
80
- location = "https://raw.githubusercontent.com/unitxt/unitxt/main/catalog/community.json"
81
 
82
- def load(self, artifact_identifier: str):
83
- pass
 
 
 
 
 
 
 
84
 
85
 
86
- def add_to_catalog(artifact: Artifact, name: str, collection=str, catalog: Catalog = None, overwrite: bool = False):
87
- if catalog is None:
88
- catalog = LocalCatalog()
89
- catalog.save(artifact, name, collection, overwrite=overwrite)
 
1
  import os
2
+ import re
3
+ from pathlib import Path
4
+ import requests
5
+ import json
6
+ from .artifact import Artifact, Artifactory
7
 
8
+
9
+ COLLECTION_SEPARATOR = '::'
10
+ PATHS_SEP = ':'
11
 
12
 
13
  class Catalog(Artifactory):
 
18
# Resolve the directory that holds the bundled catalog.
# Prefer the installed ``unitxt`` package; when it cannot be imported, fall
# back to a "catalog" directory that sits next to this module.
try:
    import unitxt
except ImportError:
    default_catalog_path = os.path.dirname(__file__) + "/catalog"
else:
    default_catalog_path = os.path.dirname(unitxt.__file__) + "/catalog"
24
 
25
 
26
  class LocalCatalog(Catalog):
27
  name: str = "local"
28
+ location: str = default_catalog_path
 
 
 
 
 
 
 
 
29
 
30
  def path(self, artifact_identifier: str):
31
+ return os.path.join(self.location, *(artifact_identifier + ".json").split(COLLECTION_SEPARATOR))
32
 
33
  def load(self, artifact_identifier: str):
34
  assert artifact_identifier in self, "Artifact with name {} does not exist".format(artifact_identifier)
 
47
  return False
48
  return os.path.exists(path) and os.path.isfile(path)
49
 
 
 
50
 
51
+
52
+ def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
53
+ assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
54
  if not overwrite:
55
  assert (
56
  artifact_identifier not in self
57
  ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
58
+ path = self.path(artifact_identifier)
59
+ os.makedirs(Path(path).parent.absolute(), exist_ok=True)
 
 
60
  artifact.save(path)
61
 
62
 
63
class GithubCatalog(LocalCatalog):
    """Catalog whose artifacts are fetched as raw JSON files from a GitHub repo.

    The artifact URL is built from ``user``, ``repo``, ``branch`` and
    ``repo_dir``; every ``COLLECTION_SEPARATOR`` in an identifier becomes one
    URL path level and ``.json`` is appended to the last one.
    """

    name = "community"
    repo = "unitxt"
    repo_dir = "src/unitxt/catalog"
    user = "IBM"
    branch = "master"
    # Seconds to wait on GitHub before giving up; without a timeout a stalled
    # connection would block the caller indefinitely.
    request_timeout = 10

    def prepare(self):
        """Point ``location`` at the raw-content base URL of the repository.

        NOTE(review): presumably invoked by the Artifact base machinery after
        construction - confirm against ``artifact.py``.
        """
        self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{self.branch}/{self.repo_dir}"

    def path(self, artifact_identifier: str):
        """Return the URL of the JSON file backing ``artifact_identifier``.

        Overrides the filesystem variant, which joins with ``os.path.join``
        and would therefore insert backslashes into the URL on Windows.
        """
        parts = (artifact_identifier + ".json").split(COLLECTION_SEPARATOR)
        return "/".join([self.location, *parts])

    def load(self, artifact_identifier: str):
        """Download the artifact JSON and build an ``Artifact`` from it.

        Raises:
            requests.HTTPError: If GitHub answers with an error status (for
                example 404 for an unknown identifier).
            requests.Timeout: If no answer arrives within ``request_timeout``.
        """
        url = self.path(artifact_identifier)
        response = requests.get(url, timeout=self.request_timeout)
        # Fail with the HTTP status instead of a confusing JSON decode error
        # when the artifact does not exist.
        response.raise_for_status()
        return Artifact.from_dict(response.json())

    def __contains__(self, artifact_identifier: str):
        """Return True when a HEAD request for the artifact URL answers 200."""
        url = self.path(artifact_identifier)
        response = requests.head(url, timeout=self.request_timeout)
        return response.status_code == 200
83
+
84
+
85
 
 
 
 
 
 
86
 
87
def verify_legal_catalog_name(name, separator: str = None):
    """Check that ``name`` is a legal catalog artifact identifier.

    A legal name consists of one or more non-empty segments of word
    characters (letters, digits, underscore) joined by the collection
    separator, e.g. ``"cards::wnli"``. Each separator marks one directory
    level, taking the place of ``"/"``.

    Args:
        name: The identifier to validate.
        separator: Optional override of the collection separator. When
            omitted, the module-level ``COLLECTION_SEPARATOR`` is used.

    Raises:
        AssertionError: If ``name`` is empty, contains any character other
            than word characters and the separator, has an empty segment
            (leading, trailing or doubled separator), or ends with a newline.
    """
    if separator is None:
        separator = COLLECTION_SEPARATOR
    # Escape the separator so it is matched literally and as a whole unit;
    # placing it inside a character class would accept any single character
    # of it (a lone ":") anywhere in the name.
    segment = r"\w+"
    pattern = f"{segment}(?:{re.escape(separator)}{segment})*"
    # fullmatch, unlike match with "$", also rejects a trailing newline.
    # Raised explicitly so the check is not stripped when running with -O.
    if re.fullmatch(pattern, name) is None:
        raise AssertionError(
            f'Catalog name should be alphanumeric, "{separator}" should specify dirs '
            f'(instead of "/"). Got: {name!r}'
        )
90
 
 
 
 
91
 
92
def add_to_catalog(artifact: Artifact, name: str, catalog: Catalog = None, overwrite: bool = False,
                   catalog_path: str = None):
    """Save ``artifact`` under ``name`` in a catalog.

    When no ``catalog`` is passed, a ``LocalCatalog`` is created at
    ``catalog_path`` (or at ``default_catalog_path`` when that is also
    omitted). The name is validated with ``verify_legal_catalog_name`` before
    saving; the collection is encoded in the name itself rather than passed
    separately.

    Args:
        artifact: The artifact to store.
        name: Catalog identifier of the artifact.
        catalog: Catalog to save into; a local one is created when ``None``.
        overwrite: Forwarded to ``save_artifact``.
        catalog_path: Location of the local catalog created when ``catalog``
            is ``None``; ignored otherwise.
    """
    target = catalog
    if target is None:
        location = default_catalog_path if catalog_path is None else catalog_path
        target = LocalCatalog(location=location)
    verify_legal_catalog_name(name)
    target.save_artifact(artifact, name, overwrite=overwrite)
 
102