Elron commited on
Commit
66c1161
·
1 Parent(s): 852c364

Upload artifact.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. artifact.py +28 -10
artifact.py CHANGED
@@ -4,8 +4,8 @@ import os
4
  import pkgutil
5
  from abc import ABC, abstractmethod
6
  from dataclasses import asdict, dataclass, field, fields
7
- from typing import final
8
-
9
  from .text_utils import camel_to_snake_case, is_camel_case
10
 
11
 
@@ -31,7 +31,7 @@ class Artifactories(object):
31
  assert isinstance(artifactory, Artifactory), "Artifactory must be an instance of Artifactory"
32
  assert hasattr(artifactory, "__contains__"), "Artifactory must have __contains__ method"
33
  assert hasattr(artifactory, "__getitem__"), "Artifactory must have __getitem__ method"
34
- self.artifactories.append(artifactory)
35
 
36
 
37
  class BaseArtifact(ABC):
@@ -89,14 +89,12 @@ class BaseArtifact(ABC):
89
  self._args_dict = asdict(self)
90
 
91
  for field in fields(self):
92
- # check if field.type is class and if it is subclass of BaseArtifact
93
- if isinstance(field.type, type) and issubclass(field.type, BaseArtifact):
 
94
  value = getattr(self, field.name)
95
- if isinstance(value, str):
96
- artifact, artifactory = fetch_artifact(value)
97
- assert artifact is not None, f"Artifact {value} does not exist, in {Artifactories()}"
98
- print(f"Artifact {value} is fetched from {artifactory}")
99
- setattr(self, field.name, artifact)
100
 
101
  self.prepare()
102
  self.verify()
@@ -145,6 +143,16 @@ class BaseArtifact(ABC):
145
  # cls = cls._class_register[d.pop('type')]
146
  # return cls(**d)
147
 
 
 
 
 
 
 
 
 
 
 
148
 
149
  class Artifact(BaseArtifact):
150
  type: str = field(init=False)
@@ -185,6 +193,16 @@ def fetch_artifact(name):
185
 
186
  raise UnitxtArtifactNotFoundError(name, Artifactories().artifactories)
187
 
 
 
 
 
 
 
 
 
 
 
188
 
189
  def register_all_artifacts(path):
190
  for loader, module_name, is_pkg in pkgutil.walk_packages(path):
 
4
  import pkgutil
5
  from abc import ABC, abstractmethod
6
  from dataclasses import asdict, dataclass, field, fields
7
+ from typing import final, Any, Dict, List, Union
8
+ from .type_utils import issubtype
9
  from .text_utils import camel_to_snake_case, is_camel_case
10
 
11
 
 
31
  assert isinstance(artifactory, Artifactory), "Artifactory must be an instance of Artifactory"
32
  assert hasattr(artifactory, "__contains__"), "Artifactory must have __contains__ method"
33
  assert hasattr(artifactory, "__getitem__"), "Artifactory must have __getitem__ method"
34
+ self.artifactories = [artifactory] + self.artifactories
35
 
36
 
37
  class BaseArtifact(ABC):
 
89
  self._args_dict = asdict(self)
90
 
91
  for field in fields(self):
92
+ if getattr(self, field.name) == 'cards.wnli':
93
+ print('cards.wnli')
94
+ if issubtype(field.type, Union[BaseArtifact, List[BaseArtifact], Dict[str, BaseArtifact]]):
95
  value = getattr(self, field.name)
96
+ value = map_values_in_place(value, maybe_recover_artifact)
97
+ setattr(self, field.name, value)
 
 
 
98
 
99
  self.prepare()
100
  self.verify()
 
143
  # cls = cls._class_register[d.pop('type')]
144
  # return cls(**d)
145
 
146
+ def map_values_in_place(object, mapper):
147
+ if isinstance(object, dict):
148
+ for key, value in object.items():
149
+ object[key] = mapper(value)
150
+ return object
151
+ if isinstance(object, list):
152
+ for i in range(len(object)):
153
+ object[i] = mapper(object[i])
154
+ return object
155
+ return mapper(object)
156
 
157
  class Artifact(BaseArtifact):
158
  type: str = field(init=False)
 
193
 
194
  raise UnitxtArtifactNotFoundError(name, Artifactories().artifactories)
195
 
196
+ def verbosed_fetch_artifact(identifer):
197
+ artifact, artifactory = fetch_artifact(identifer)
198
+ print(f"Artifact {identifer} is fetched from {artifactory}")
199
+ return artifact
200
+
201
+ def maybe_recover_artifact(artifact):
202
+ if isinstance(artifact, str):
203
+ return verbosed_fetch_artifact(artifact)
204
+ else:
205
+ return artifact
206
 
207
  def register_all_artifacts(path):
208
  for loader, module_name, is_pkg in pkgutil.walk_packages(path):