Elron commited on
Commit
b5ae7e6
1 Parent(s): 118aaef

Upload register.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. register.py +42 -4
register.py CHANGED
@@ -1,8 +1,11 @@
1
- from .artifact import Artifact
2
- from . import blocks
3
-
4
  import inspect
5
 
 
 
 
6
 
7
  def register_blocks():
8
  # Iterate over every object in the blocks module
@@ -13,5 +16,40 @@ def register_blocks():
13
  if issubclass(obj, Artifact) and obj is not Artifact:
14
  Artifact.register_class(obj)
15
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
- register_blocks()
 
 
1
+ import inspect
2
+ import os
3
+ import importlib
4
  import inspect
5
 
6
+ from . import blocks
7
+ from .artifact import Artifact
8
+ from .utils import Singleton
9
 
10
  def register_blocks():
11
  # Iterate over every object in the blocks module
 
16
  if issubclass(obj, Artifact) and obj is not Artifact:
17
  Artifact.register_class(obj)
18
 
19
+ # Usage
20
+ non_registered_files = ['__init__.py', 'artifact.py', 'utils.py', 'register.py', 'metric.py', 'dataset.py', 'blocks.py']
21
+
22
+ def _register_all_artifacts():
23
+
24
+ dir = os.path.dirname(__file__)
25
+ file_name = os.path.basename(__file__)
26
+
27
+ for file in os.listdir(dir):
28
+ if file.endswith('.py') and file not in non_registered_files and file != file_name:
29
+ module_name = file.replace('.py', '')
30
+
31
+ module = importlib.import_module('.' + module_name, __package__)
32
+
33
+ for name, obj in inspect.getmembers(module):
34
+ # Make sure the object is a class
35
+ if inspect.isclass(obj):
36
+ # Make sure the class is a subclass of Artifact (but not Artifact itself)
37
+ if issubclass(obj, Artifact) and obj is not Artifact:
38
+ Artifact.register_class(obj)
39
+
40
+
41
+
42
+ class ProjectArtifactRegisterer():
43
+
44
+ def __init__(self):
45
+
46
+ if not hasattr(self, '_registered'):
47
+ self._registered = False
48
+
49
+ if not self._registered:
50
+ _register_all_artifacts()
51
+ self._registered = True
52
+
53
 
54
+ def register_all_artifacts():
55
+ ProjectArtifactRegisterer()