Elron commited on
Commit
1b98aa7
·
1 Parent(s): e5a087b

Upload catalog.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. catalog.py +86 -0
catalog.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from unitxt.artifact import Artifact
2
+ from .artifact import Artifact, Artifactory, register_atrifactory
3
+ from abc import abstractmethod
4
+ from dataclasses import field
5
+ import os
6
+ from .file_utils import get_all_files_in_dir
7
+
8
+
9
+ class Catalog(Artifactory):
10
+ name: str = None
11
+ location: str = None
12
+
13
+
14
+ catalog_path = os.path.dirname(__file__) + "/catalog"
15
+
16
+
17
+ class LocalCatalog(Catalog):
18
+ name: str = "local"
19
+ location: str = catalog_path
20
+
21
+ @property
22
+ def path_dict(self):
23
+ result = {}
24
+ for path in get_all_files_in_dir(self.location, recursive=True, file_extension=".json"):
25
+ name = os.path.splitext(os.path.basename(path))[0]
26
+ result[name] = path
27
+ return result
28
+
29
+ def path(self, artifact_identifier: str):
30
+ return self.path_dict.get(artifact_identifier, None)
31
+
32
+ def load(self, artifact_identifier: str):
33
+ assert artifact_identifier in self, "Artifact with name {} does not exist".format(artifact_identifier)
34
+ path = self.path(artifact_identifier)
35
+ artifact_instance = Artifact.load(path)
36
+ return artifact_instance
37
+
38
+ def __getitem__(self, name) -> Artifact:
39
+ return self.load(name)
40
+
41
+ def __contains__(self, artifact_identifier: str):
42
+ if not os.path.exists(self.location):
43
+ return False
44
+ path = self.path(artifact_identifier)
45
+ if path is None:
46
+ return False
47
+ return os.path.exists(path) and os.path.isfile(path)
48
+
49
+ def save(self, artifact: Artifact, artifact_identifier: str, collection: str, overwrite: bool = False):
50
+ assert isinstance(artifact, Artifact), "Artifact must be an instance of Artifact"
51
+
52
+ if not overwrite:
53
+ assert (
54
+ artifact_identifier not in self
55
+ ), f"Artifact with name {artifact_identifier} already exists in catalog {self.name}"
56
+
57
+ collection_dir = os.path.join(self.location, collection)
58
+ os.makedirs(collection_dir, exist_ok=True)
59
+ path = os.path.join(collection_dir, artifact_identifier + ".json")
60
+ artifact.save(path)
61
+
62
+
63
+ register_atrifactory(LocalCatalog())
64
+
65
+ try:
66
+ import unitxt
67
+
68
+ library_catalog = LocalCatalog("library", unitxt.__path__[0] + "/catalog")
69
+ register_atrifactory(library_catalog)
70
+ except:
71
+ pass
72
+ # create a catalog for the community
73
+
74
+
75
+ class CommunityCatalog(Catalog):
76
+ name = "community"
77
+ location = "https://raw.githubusercontent.com/unitxt/unitxt/main/catalog/community.json"
78
+
79
+ def load(self, artifact_identifier: str):
80
+ pass
81
+
82
+
83
+ def add_to_catalog(artifact: Artifact, name: str, collection=str, catalog: Catalog = None, overwrite: bool = False):
84
+ if catalog is None:
85
+ catalog = LocalCatalog()
86
+ catalog.save(artifact, name, collection, overwrite=overwrite)