Elron commited on
Commit
5a833c3
1 Parent(s): a5c21f3

Upload catalog.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. catalog.py +26 -23
catalog.py CHANGED
@@ -1,13 +1,14 @@
1
  import os
2
  import re
3
  from pathlib import Path
 
4
  import requests
5
- import json
6
- from .artifact import Artifact, Artifactory
7
 
 
 
8
 
9
- COLLECTION_SEPARATOR = '.'
10
- PATHS_SEP = ':'
11
 
12
 
13
  class Catalog(Artifactory):
@@ -18,9 +19,14 @@ class Catalog(Artifactory):
18
  try:
19
  import unitxt
20
 
21
- default_catalog_path = os.path.dirname(unitxt.__file__) + "/catalog"
 
 
 
22
  except ImportError:
23
- default_catalog_path = os.path.dirname(__file__) + "/catalog"
 
 
24
 
25
 
26
  class LocalCatalog(Catalog):
@@ -28,7 +34,7 @@ class LocalCatalog(Catalog):
28
  location: str = default_catalog_path
29
 
30
  def path(self, artifact_identifier: str):
31
- assert artifact_identifier.strip(), 'artifact_identifier should not be an empty string.'
32
  parts = artifact_identifier.split(COLLECTION_SEPARATOR)
33
  parts[-1] = parts[-1] + ".json"
34
  return os.path.join(self.location, *parts)
@@ -50,8 +56,6 @@ class LocalCatalog(Catalog):
50
  return False
51
  return os.path.exists(path) and os.path.isfile(path)
52
 
53
-
54
-
55
  def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
56
  assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
57
  if not overwrite:
@@ -61,6 +65,7 @@ class LocalCatalog(Catalog):
61
  path = self.path(artifact_identifier)
62
  os.makedirs(Path(path).parent.absolute(), exist_ok=True)
63
  artifact.save(path)
 
64
 
65
 
66
  class GithubCatalog(LocalCatalog):
@@ -68,38 +73,36 @@ class GithubCatalog(LocalCatalog):
68
  repo = "unitxt"
69
  repo_dir = "src/unitxt/catalog"
70
  user = "IBM"
71
- branch = "master"
72
-
73
  def prepare(self):
74
- self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{self.branch}/{self.repo_dir}"
75
-
 
76
  def load(self, artifact_identifier: str):
77
  url = self.path(artifact_identifier)
78
  response = requests.get(url)
79
  data = response.json()
80
  return Artifact.from_dict(data)
81
-
82
  def __contains__(self, artifact_identifier: str):
83
  url = self.path(artifact_identifier)
84
  response = requests.head(url)
85
  return response.status_code == 200
86
-
87
-
88
 
89
 
90
  def verify_legal_catalog_name(name):
91
- assert re.match('^[\w' + COLLECTION_SEPARATOR + ']+$', name),\
92
- 'Catalog name should be alphanumeric, ":" should specify dirs (instead of "/").'
 
93
 
94
 
95
- def add_to_catalog(artifact: Artifact, name: str, catalog: Catalog = None, overwrite: bool = False,
96
- catalog_path: str = None):
 
97
  if catalog is None:
98
  if catalog_path is None:
99
  catalog_path = default_catalog_path
100
  catalog = LocalCatalog(location=catalog_path)
101
  verify_legal_catalog_name(name)
102
- catalog.save_artifact(artifact, name, overwrite=overwrite) # remove collection (its actually the dir).
103
  # verify name
104
-
105
-
 
1
  import os
2
  import re
3
  from pathlib import Path
4
+
5
  import requests
 
 
6
 
7
+ from ._version import get_current_version
8
+ from .artifact import Artifact, Artifactory
9
 
10
+ COLLECTION_SEPARATOR = "."
11
+ PATHS_SEP = ":"
12
 
13
 
14
  class Catalog(Artifactory):
 
19
  try:
20
  import unitxt
21
 
22
+ if unitxt.__file__:
23
+ lib_dir = os.path.dirname(unitxt.__file__)
24
+ else:
25
+ lib_dir = os.path.dirname(__file__)
26
  except ImportError:
27
+ lib_dir = os.path.dirname(__file__)
28
+
29
+ default_catalog_path = os.path.join(lib_dir, "catalog")
30
 
31
 
32
  class LocalCatalog(Catalog):
 
34
  location: str = default_catalog_path
35
 
36
  def path(self, artifact_identifier: str):
37
+ assert artifact_identifier.strip(), "artifact_identifier should not be an empty string."
38
  parts = artifact_identifier.split(COLLECTION_SEPARATOR)
39
  parts[-1] = parts[-1] + ".json"
40
  return os.path.join(self.location, *parts)
 
56
  return False
57
  return os.path.exists(path) and os.path.isfile(path)
58
 
 
 
59
  def save_artifact(self, artifact: Artifact, artifact_identifier: str, overwrite: bool = False):
60
  assert isinstance(artifact, Artifact), f"Input artifact must be an instance of Artifact, got {type(artifact)}"
61
  if not overwrite:
 
65
  path = self.path(artifact_identifier)
66
  os.makedirs(Path(path).parent.absolute(), exist_ok=True)
67
  artifact.save(path)
68
+ print(f"Artifact {artifact_identifier} saved to {path}")
69
 
70
 
71
  class GithubCatalog(LocalCatalog):
 
73
  repo = "unitxt"
74
  repo_dir = "src/unitxt/catalog"
75
  user = "IBM"
76
+
 
77
  def prepare(self):
78
+ tag = get_current_version().split("+")[0]
79
+ self.location = f"https://raw.githubusercontent.com/{self.user}/{self.repo}/{tag}/{self.repo_dir}"
80
+
81
  def load(self, artifact_identifier: str):
82
  url = self.path(artifact_identifier)
83
  response = requests.get(url)
84
  data = response.json()
85
  return Artifact.from_dict(data)
86
+
87
  def __contains__(self, artifact_identifier: str):
88
  url = self.path(artifact_identifier)
89
  response = requests.head(url)
90
  return response.status_code == 200
 
 
91
 
92
 
93
  def verify_legal_catalog_name(name):
94
+ assert re.match(
95
+ r"^[\w" + COLLECTION_SEPARATOR + "]+$", name
96
+ ), 'Catalog name should be alphanumeric, ":" should specify dirs (instead of "/").'
97
 
98
 
99
+ def add_to_catalog(
100
+ artifact: Artifact, name: str, catalog: Catalog = None, overwrite: bool = False, catalog_path: str = None
101
+ ):
102
  if catalog is None:
103
  if catalog_path is None:
104
  catalog_path = default_catalog_path
105
  catalog = LocalCatalog(location=catalog_path)
106
  verify_legal_catalog_name(name)
107
+ catalog.save_artifact(artifact, name, overwrite=overwrite) # remove collection (its actually the dir).
108
  # verify name