Elron committed on
Commit
c6d1c21
1 Parent(s): 147cebb

Upload artifact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. artifact.py +194 -0
artifact.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, asdict, fields, field
2
+ from abc import ABC, abstractmethod
3
+ from typing import final
4
+ import re
5
+ import json
6
+ import os
7
+ import pkgutil
8
+ import inspect
9
+
10
+
11
class AbstractField:
    """Empty marker class; it declares no attributes or methods of its own."""
13
+
14
+
15
+ from .text_utils import camel_to_snake_case, is_camel_case
16
+
17
# Module-level registry of artifactories. register_atrifactory() appends to it
# and fetch_artifact() searches it in insertion order.
artifactories = []
18
+
19
+
20
class BaseArtifact(ABC):
    """Base class for objects that round-trip through "type"-tagged dicts.

    Every subclass is turned into a dataclass when it is defined. When an
    instance is created, its class is registered in ``_class_register`` under
    the snake_case form of the class name, and that key is stored on the
    instance as ``type`` so a saved dict can be mapped back to its class.
    """

    # Maps type keys (as returned by register_class) to artifact classes.
    # Subclasses that do not override this attribute share the same dict.
    _class_register = {}

    @classmethod
    def is_artifact_dict(cls, d):
        """Return True if ``d`` is a dict whose "type" value is a registered key."""
        if not isinstance(d, dict) or "type" not in d:
            return False
        try:
            return d["type"] in cls._class_register
        except TypeError:
            # An unhashable "type" value (e.g. a list or nested dict) cannot be
            # a registry key, so this is simply not an artifact dict.
            return False

    @classmethod
    def register_class(cls, artifact_class):
        """Register ``artifact_class`` and return its snake_case type key.

        Asserts that the class derives from BaseArtifact, that its name is
        camel case, and that the key is not already taken by another class.
        Registering the same class again is allowed.
        """
        assert issubclass(artifact_class, BaseArtifact), "Artifact class must be a subclass of BaseArtifact"
        assert is_camel_case(
            artifact_class.__name__
        ), f"Artifact class name must be legal camel case, got {artifact_class.__name__}"

        snake_case_key = camel_to_snake_case(artifact_class.__name__)

        if snake_case_key in cls._class_register:
            assert (
                cls._class_register[snake_case_key] == artifact_class
            ), f"Artifact class name must be unique, {snake_case_key} already exists for {cls._class_register[snake_case_key]}"

        cls._class_register[snake_case_key] = artifact_class

        return snake_case_key

    @classmethod
    def is_artifact_file(cls, path):
        """Return True if ``path`` is a JSON file containing an artifact dict.

        A missing path, a directory, or a file that cannot be decoded as
        JSON all yield False instead of raising, so the method can be used
        to probe arbitrary names.
        """
        if not os.path.isfile(path):
            return False
        try:
            with open(path, "r", encoding="utf-8") as f:
                d = json.load(f)
        except (json.JSONDecodeError, UnicodeDecodeError):
            return False
        return cls.is_artifact_dict(d)

    @final
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @final
    def __init_subclass__(cls, **kwargs):
        """Convert every subclass into a dataclass at definition time."""
        super().__init_subclass__(**kwargs)
        # dataclass() modifies the class in place and returns the same object.
        dataclass(cls)

    def prepare(self):
        """Hook called at the end of ``__post_init__``; does nothing by default."""

    def verify(self):
        """Hook called after ``prepare``; does nothing by default."""

    @final
    def __post_init__(self):
        """Register the class, snapshot the arguments and resolve references.

        Fields annotated with a BaseArtifact subclass whose value is a string
        are replaced by the artifact that ``fetch_artifact`` returns for that
        string. ``prepare`` and ``verify`` are then called, in that order.
        """
        self.type = self.register_class(self.__class__)

        # Taken before references are resolved, so the saved form keeps the
        # original string names rather than the fetched artifacts.
        self._args_dict = asdict(self)

        for field_info in fields(self):
            # Only fields whose annotation is an actual artifact class qualify.
            if isinstance(field_info.type, type) and issubclass(field_info.type, BaseArtifact):
                value = getattr(self, field_info.name)
                if isinstance(value, str):
                    artifact, artifactory = fetch_artifact(value)
                    assert artifact is not None, f"Artifact {value} does not exist, in {artifactories}"
                    print(f"Artifact {value} is fetched from {artifactory}")
                    setattr(self, field_info.name, artifact)

        self.prepare()
        self.verify()

    def to_dict(self):
        """Return the dict snapshot taken in ``__post_init__``."""
        return self._args_dict

    def save(self, path):
        """Write ``to_dict()`` to ``path`` as JSON indented by 4 spaces."""
        with open(path, "w", encoding="utf-8") as f:
            json.dump(self.to_dict(), f, indent=4)

    @classmethod
    def _recursive_load(cls, d):
        """Rebuild artifacts from nested dicts and lists, innermost first.

        Dicts and lists are copied while their children are loaded. A dict
        accepted by ``is_artifact_dict`` is replaced by an instance of the
        registered class, built from the remaining keys. Any other value is
        returned as is.
        """
        if isinstance(d, dict):
            d = {key: cls._recursive_load(value) for key, value in d.items()}
        elif isinstance(d, list):
            d = [cls._recursive_load(value) for value in d]

        if cls.is_artifact_dict(d):
            # d is the fresh copy built above, so popping "type" leaves the
            # caller's dict untouched.
            artifact_class = cls._class_register[d.pop("type")]
            return artifact_class(**d)
        return d

    @classmethod
    def from_dict(cls, d):
        """Build an artifact from ``d``, which must be an artifact dict."""
        assert cls.is_artifact_dict(d), "Input must be a dict with type field"
        return cls._recursive_load(d)

    @classmethod
    def load(cls, path):
        """Read the JSON file at ``path`` and rebuild the artifact it holds."""
        with open(path, "r", encoding="utf-8") as f:
            d = json.load(f)

        assert "type" in d, "Saved artifact must have a type field"
        return cls._recursive_load(d)
130
+
131
+
132
class Artifact(BaseArtifact):
    """Artifact base class that declares the ``type`` field."""

    # Left out of the generated __init__; BaseArtifact.__post_init__ assigns
    # it from the key returned by register_class.
    type: str = field(init=False)
134
+
135
+
136
class ArtifactList(list, Artifact):
    """A list that is also an Artifact."""

    def prepare(self):
        """Call ``prepare`` on every element, in list order."""
        for member in self:
            member.prepare()
140
+
141
+
142
class Artifactory(Artifact, ABC):
    """Abstract artifact that supports ``name in self`` and ``self[name]``."""

    @abstractmethod
    def __contains__(self, name: str) -> bool:
        """Return whether ``name`` is available from this artifactory."""

    @abstractmethod
    def __getitem__(self, name) -> Artifact:
        """Return the artifact stored under ``name``."""
150
+
151
+
152
class UnitxtArtifactNotFoundError(Exception):
    """Raised when an artifact name cannot be resolved.

    Attributes:
        name: The artifact name that was looked up.
        artifactories: The artifactories that were searched.
    """

    def __init__(self, name, artifactories):
        # Pass the arguments on to Exception so that ``args`` is populated.
        # The default copy/pickle support rebuilds an exception by calling
        # the class with ``args``, which fails if they are left empty.
        super().__init__(name, artifactories)
        self.name = name
        self.artifactories = artifactories

    def __str__(self):
        return f"Artifact {self.name} does not exist, in artifactories:{self.artifactories}"
159
+
160
+
161
def fetch_artifact(name):
    """Resolve ``name`` to an artifact and report where it came from.

    Returns:
        ``(artifact, None)`` when ``name`` is the path of an artifact file,
        otherwise ``(artifact, artifactory)`` for the first entry of
        ``artifactories`` that contains ``name``.

    Raises:
        UnitxtArtifactNotFoundError: If ``name`` is neither an artifact file
            nor contained in any registered artifactory.
    """
    if Artifact.is_artifact_file(name):
        return Artifact.load(name), None

    # First registered artifactory that knows the name, or None if none does.
    source = next((candidate for candidate in artifactories if name in candidate), None)
    if source is None:
        raise UnitxtArtifactNotFoundError(name, artifactories)
    return source[name], source
170
+
171
+
172
def register_atrifactory(artifactory):
    """Append ``artifactory`` to the module-level ``artifactories`` list.

    Asserts that it is an Artifactory instance and that it exposes both
    ``__contains__`` and ``__getitem__``.
    """
    assert isinstance(artifactory, Artifactory), "Artifactory must be an instance of Artifactory"
    for required in ("__contains__", "__getitem__"):
        assert hasattr(artifactory, required), f"Artifactory must have {required} method"
    artifactories.append(artifactory)
177
+
178
+
179
def register_all_artifacts(path):
    """Import every module found under ``path`` and print its Artifact subclasses.

    Args:
        path: The search locations handed to ``pkgutil.walk_packages``
            (a list of directory paths).

    A module whose name equals this module's ``__name__`` is skipped. Each
    imported module is stored in ``sys.modules`` under its name.
    """
    # Imported locally because only this function needs them.
    import importlib.util
    import sys

    for finder, module_name, _is_pkg in pkgutil.walk_packages(path):
        print(__name__)
        if module_name == __name__:
            continue
        print(f"Loading {module_name}")

        # The legacy find_module()/load_module() pair was removed from the
        # import machinery in Python 3.12, so load through the spec API.
        spec = finder.find_spec(module_name)
        if spec is None or spec.loader is None:
            continue
        module = importlib.util.module_from_spec(spec)
        # Register before executing, as a regular import would, and undo the
        # registration if the module body fails.
        sys.modules[module_name] = module
        try:
            spec.loader.exec_module(module)
        except BaseException:
            sys.modules.pop(module_name, None)
            raise

        # Report every class that derives from Artifact, but not Artifact itself.
        for _name, obj in inspect.getmembers(module):
            if inspect.isclass(obj) and issubclass(obj, Artifact) and obj is not Artifact:
                print(obj)