Elron commited on
Commit
eac4eaf
1 Parent(s): 64bbd46

Upload artifact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. artifact.py +168 -72
artifact.py CHANGED
@@ -3,19 +3,17 @@ import json
3
  import os
4
  import pkgutil
5
  from abc import ABC, abstractmethod
6
- from dataclasses import asdict, dataclass, field, fields
7
- from typing import final, Any, Dict, List, Union
8
- from .type_utils import issubtype
9
- from .text_utils import camel_to_snake_case, is_camel_case
10
-
11
 
12
- class AbstractField:
13
- pass
 
14
 
15
 
16
  class Artifactories(object):
17
  def __new__(cls):
18
- if not hasattr(cls, 'instance'):
19
  cls.instance = super(Artifactories, cls).__new__(cls)
20
  cls.instance.artifactories = []
21
 
@@ -34,7 +32,131 @@ class Artifactories(object):
34
  self.artifactories = [artifactory] + self.artifactories
35
 
36
 
37
- class BaseArtifact(ABC):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  _class_register = {}
39
 
40
  @classmethod
@@ -43,7 +165,9 @@ class BaseArtifact(ABC):
43
 
44
  @classmethod
45
  def register_class(cls, artifact_class):
46
- assert issubclass(artifact_class, BaseArtifact), "Artifact class must be a subclass of BaseArtifact"
 
 
47
  assert is_camel_case(
48
  artifact_class.__name__
49
  ), f"Artifact class name must be legal camel case, got {artifact_class.__name__}"
@@ -67,49 +191,6 @@ class BaseArtifact(ABC):
67
  d = json.load(f)
68
  return cls.is_artifact_dict(d)
69
 
70
- @final
71
- def __init__(self, *args, **kwargs):
72
- super().__init__(*args, **kwargs)
73
-
74
- @final
75
- def __init_subclass__(cls, **kwargs):
76
- super().__init_subclass__(**kwargs)
77
- cls = dataclass(cls)
78
-
79
- def prepare(self):
80
- pass
81
-
82
- def verify(self):
83
- pass
84
-
85
- @final
86
- def __post_init__(self):
87
- self.type = self.register_class(self.__class__)
88
-
89
- self._args_dict = asdict(self)
90
-
91
- for field in fields(self):
92
- if getattr(self, field.name) == 'cards.wnli':
93
- print('cards.wnli')
94
- if issubtype(field.type, Union[BaseArtifact, List[BaseArtifact], Dict[str, BaseArtifact]]):
95
- value = getattr(self, field.name)
96
- value = map_values_in_place(value, maybe_recover_artifact)
97
- setattr(self, field.name, value)
98
-
99
- self.prepare()
100
- self.verify()
101
-
102
- def to_dict(self):
103
- return self._args_dict
104
-
105
- def save(self, path):
106
- with open(path, "w") as f:
107
- json.dump(self.to_dict(), f, indent=4)
108
-
109
- # def __getstate__(self):
110
- # print('getstate', self.__dict__)
111
- # return self.to_dict()
112
-
113
  @classmethod
114
  def _recursive_load(cls, d):
115
  if isinstance(d, dict):
@@ -139,23 +220,35 @@ class BaseArtifact(ABC):
139
 
140
  assert "type" in d, "Saved artifact must have a type field"
141
  return cls._recursive_load(d)
142
- # assert d['type'] in cls._class_register, f'Artifact type "{d["type"]}" is not registered'
143
- # cls = cls._class_register[d.pop('type')]
144
- # return cls(**d)
145
 
146
- def map_values_in_place(object, mapper):
147
- if isinstance(object, dict):
148
- for key, value in object.items():
149
- object[key] = mapper(value)
150
- return object
151
- if isinstance(object, list):
152
- for i in range(len(object)):
153
- object[i] = mapper(object[i])
154
- return object
155
- return mapper(object)
156
 
157
- class Artifact(BaseArtifact):
158
- type: str = field(init=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  class ArtifactList(list, Artifact):
@@ -164,7 +257,7 @@ class ArtifactList(list, Artifact):
164
  artifact.prepare()
165
 
166
 
167
- class Artifactory(Artifact, ABC):
168
  @abstractmethod
169
  def __contains__(self, name: str) -> bool:
170
  pass
@@ -193,10 +286,12 @@ def fetch_artifact(name):
193
 
194
  raise UnitxtArtifactNotFoundError(name, Artifactories().artifactories)
195
 
 
196
  def verbosed_fetch_artifact(identifer):
197
- artifact, artifactory = fetch_artifact(identifer)
198
- print(f"Artifact {identifer} is fetched from {artifactory}")
199
- return artifact
 
200
 
201
  def maybe_recover_artifact(artifact):
202
  if isinstance(artifact, str):
@@ -204,6 +299,7 @@ def maybe_recover_artifact(artifact):
204
  else:
205
  return artifact
206
 
 
207
  def register_all_artifacts(path):
208
  for loader, module_name, is_pkg in pkgutil.walk_packages(path):
209
  print(__name__)
 
3
  import os
4
  import pkgutil
5
  from abc import ABC, abstractmethod
6
+ from copy import deepcopy
7
+ from typing import Any, Dict, List, Union, final
 
 
 
8
 
9
+ from .dataclass import Dataclass, Field, asdict, fields
10
+ from .text_utils import camel_to_snake_case, is_camel_case
11
+ from .type_utils import issubtype
12
 
13
 
14
  class Artifactories(object):
15
  def __new__(cls):
16
+ if not hasattr(cls, "instance"):
17
  cls.instance = super(Artifactories, cls).__new__(cls)
18
  cls.instance.artifactories = []
19
 
 
32
  self.artifactories = [artifactory] + self.artifactories
33
 
34
 
35
+ # class BaseArtifact(ABC):
36
+ # _class_register = {}
37
+
38
+ # @classmethod
39
+ # def is_artifact_dict(cls, d):
40
+ # return isinstance(d, dict) and "type" in d and d["type"] in cls._class_register
41
+
42
+ # @classmethod
43
+ # def register_class(cls, artifact_class):
44
+ # assert issubclass(artifact_class, BaseArtifact), "Artifact class must be a subclass of BaseArtifact"
45
+ # assert is_camel_case(
46
+ # artifact_class.__name__
47
+ # ), f"Artifact class name must be legal camel case, got {artifact_class.__name__}"
48
+
49
+ # snake_case_key = camel_to_snake_case(artifact_class.__name__)
50
+
51
+ # if snake_case_key in cls._class_register:
52
+ # assert (
53
+ # cls._class_register[snake_case_key] == artifact_class
54
+ # ), f"Artifact class name must be unique, {snake_case_key} already exists for {cls._class_register[snake_case_key]}"
55
+
56
+ # cls._class_register[snake_case_key] = artifact_class
57
+
58
+ # return snake_case_key
59
+
60
+ # @classmethod
61
+ # def is_artifact_file(cls, path):
62
+ # if not os.path.exists(path) or not os.path.isfile(path):
63
+ # return False
64
+ # with open(path, "r") as f:
65
+ # d = json.load(f)
66
+ # return cls.is_artifact_dict(d)
67
+
68
+ # @final
69
+ # def __init__(self, *args, **kwargs):
70
+ # super().__init__(*args, **kwargs)
71
+
72
+ # @final
73
+ # def __init_subclass__(cls, **kwargs):
74
+ # super().__init_subclass__(**kwargs)
75
+ # cls = dataclass(cls)
76
+
77
+ # def prepare(self):
78
+ # pass
79
+
80
+ # def verify(self):
81
+ # pass
82
+
83
+ # @final
84
+ # def __post_init__(self):
85
+ # self.type = self.register_class(self.__class__)
86
+
87
+ # self._args_dict = asdict(self)
88
+
89
+ # for field in fields(self):
90
+ # if getattr(self, field.name) == "cards.wnli":
91
+ # print("cards.wnli")
92
+ # if issubtype(field.type, Union[BaseArtifact, List[BaseArtifact], Dict[str, BaseArtifact]]):
93
+ # value = getattr(self, field.name)
94
+ # value = map_values_in_place(value, maybe_recover_artifact)
95
+ # setattr(self, field.name, value)
96
+
97
+ # self.prepare()
98
+ # self.verify()
99
+
100
+ # def to_dict(self):
101
+ # return self._args_dict
102
+
103
+ # def save(self, path):
104
+ # with open(path, "w") as f:
105
+ # json.dump(self.to_dict(), f, indent=4)
106
+
107
+ # # def __getstate__(self):
108
+ # # print('getstate', self.__dict__)
109
+ # # return self.to_dict()
110
+
111
+ # @classmethod
112
+ # def _recursive_load(cls, d):
113
+ # if isinstance(d, dict):
114
+ # new_d = {}
115
+ # for key, value in d.items():
116
+ # new_d[key] = cls._recursive_load(value)
117
+ # d = new_d
118
+ # elif isinstance(d, list):
119
+ # d = [cls._recursive_load(value) for value in d]
120
+ # else:
121
+ # pass
122
+ # if cls.is_artifact_dict(d):
123
+ # instance = cls._class_register[d.pop("type")](**d)
124
+ # return instance
125
+ # else:
126
+ # return d
127
+
128
+ # @classmethod
129
+ # def from_dict(cls, d):
130
+ # assert cls.is_artifact_dict(d), "Input must be a dict with type field"
131
+ # return cls._recursive_load(d)
132
+
133
+ # @classmethod
134
+ # def load(cls, path):
135
+ # with open(path, "r") as f:
136
+ # d = json.load(f)
137
+
138
+ # assert "type" in d, "Saved artifact must have a type field"
139
+ # return cls._recursive_load(d)
140
+ # # assert d['type'] in cls._class_register, f'Artifact type "{d["type"]}" is not registered'
141
+ # # cls = cls._class_register[d.pop('type')]
142
+ # # return cls(**d)
143
+
144
+
145
+ def map_values_in_place(object, mapper):
146
+ if isinstance(object, dict):
147
+ for key, value in object.items():
148
+ object[key] = mapper(value)
149
+ return object
150
+ if isinstance(object, list):
151
+ for i in range(len(object)):
152
+ object[i] = mapper(object[i])
153
+ return object
154
+ return mapper(object)
155
+
156
+
157
+ class Artifact(Dataclass):
158
+ type: str = Field(default=None, final=True, init=False)
159
+
160
  _class_register = {}
161
 
162
  @classmethod
 
165
 
166
  @classmethod
167
  def register_class(cls, artifact_class):
168
+ assert issubclass(
169
+ artifact_class, Artifact
170
+ ), f"Artifact class must be a subclass of Artifact, got {artifact_class}"
171
  assert is_camel_case(
172
  artifact_class.__name__
173
  ), f"Artifact class name must be legal camel case, got {artifact_class.__name__}"
 
191
  d = json.load(f)
192
  return cls.is_artifact_dict(d)
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
  @classmethod
195
  def _recursive_load(cls, d):
196
  if isinstance(d, dict):
 
220
 
221
  assert "type" in d, "Saved artifact must have a type field"
222
  return cls._recursive_load(d)
 
 
 
223
 
224
+ def prepare(self):
225
+ pass
 
 
 
 
 
 
 
 
226
 
227
+ def verify(self):
228
+ pass
229
+
230
+ @final
231
+ def __post_init__(self):
232
+ self.type = self.register_class(self.__class__)
233
+
234
+ self._init_dict = asdict(self)
235
+
236
+ for field in fields(self):
237
+ if issubtype(field.type, Union[Artifact, List[Artifact], Dict[str, Artifact]]):
238
+ value = getattr(self, field.name)
239
+ value = map_values_in_place(value, maybe_recover_artifact)
240
+ setattr(self, field.name, value)
241
+
242
+ self.prepare()
243
+ self.verify()
244
+
245
+ def get_init_dict(self):
246
+ return self._init_dict
247
+
248
+ def save(self, path):
249
+ with open(path, "w") as f:
250
+ init_dict = self.get_init_dict()
251
+ json.dump(init_dict, f, indent=4)
252
 
253
 
254
  class ArtifactList(list, Artifact):
 
257
  artifact.prepare()
258
 
259
 
260
+ class Artifactory(Artifact):
261
  @abstractmethod
262
  def __contains__(self, name: str) -> bool:
263
  pass
 
286
 
287
  raise UnitxtArtifactNotFoundError(name, Artifactories().artifactories)
288
 
289
+
290
  def verbosed_fetch_artifact(identifer):
291
+ artifact, artifactory = fetch_artifact(identifer)
292
+ print(f"Artifact {identifer} is fetched from {artifactory}")
293
+ return artifact
294
+
295
 
296
  def maybe_recover_artifact(artifact):
297
  if isinstance(artifact, str):
 
299
  else:
300
  return artifact
301
 
302
+
303
  def register_all_artifacts(path):
304
  for loader, module_name, is_pkg in pkgutil.walk_packages(path):
305
  print(__name__)