File size: 1,796 Bytes
b5ae7e6
 
 
36f3d38
 
b5ae7e6
 
 
36f3d38
 
 
 
 
 
 
 
 
 
b5ae7e6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36f3d38
b5ae7e6
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import inspect
import os
import importlib
import inspect

from . import blocks
from .artifact import Artifact
from .utils import Singleton

def register_blocks():
    """Register every strict ``Artifact`` subclass found in the ``blocks`` module."""
    # Let getmembers do the "is it a class?" filtering for us.
    for _, candidate in inspect.getmembers(blocks, inspect.isclass):
        # The Artifact base class itself and unrelated classes are skipped.
        if candidate is Artifact or not issubclass(candidate, Artifact):
            continue
        Artifact.register_class(candidate)

# File names in this package's directory that _register_all_artifacts() skips
# (matched by exact file name) when importing modules to scan for Artifact subclasses.
non_registered_files = ['__init__.py', 'artifact.py', 'utils.py', 'register.py', 'metric.py', 'dataset.py', 'blocks.py']

def _register_all_artifacts():
    """Import each sibling module of this file and register its Artifact subclasses.

    Every ``*.py`` file in this file's directory is imported as a submodule of
    the current package, except this file itself and the file names listed in
    ``non_registered_files``. Each class found in an imported module that is a
    strict subclass of ``Artifact`` is passed to ``Artifact.register_class``.

    Raises:
        ImportError: propagated from ``importlib.import_module`` if a sibling
            module cannot be imported.
    """
    # Named package_dir rather than ``dir`` so the builtin is not shadowed.
    package_dir = os.path.dirname(__file__)
    this_file = os.path.basename(__file__)

    # os.listdir() returns entries in arbitrary order; sorting makes the
    # import (and therefore registration) order reproducible everywhere.
    for file in sorted(os.listdir(package_dir)):
        # splitext strips only the trailing extension; str.replace would also
        # remove any '.py' occurring earlier in the file name.
        module_name, extension = os.path.splitext(file)
        if extension != '.py' or file in non_registered_files or file == this_file:
            continue

        module = importlib.import_module('.' + module_name, __package__)

        for _, obj in inspect.getmembers(module, inspect.isclass):
            # Register strict subclasses only, never the Artifact base itself.
            if issubclass(obj, Artifact) and obj is not Artifact:
                Artifact.register_class(obj)
            


class ProjectArtifactRegisterer:
    """Run the package-wide artifact registration scan at most once per process.

    The first instantiation calls ``_register_all_artifacts()``; every later
    instantiation is a no-op.
    """

    # Class-level flag shared by all instances. A per-instance flag is always
    # unset on a freshly created object, so it could never prevent the scan
    # from running again on the next instantiation.
    # NOTE(review): ``Singleton`` is imported at file top but unused; presumably
    # it was meant to be this class's metaclass -- confirm.
    _registered = False

    def __init__(self):
        cls = ProjectArtifactRegisterer
        if cls._registered:
            return

        _register_all_artifacts()
        # Set only after the scan returns, so a scan that raised can be retried.
        cls._registered = True
            

def register_all_artifacts():
    """Trigger artifact registration by instantiating ``ProjectArtifactRegisterer``."""
    # The instance is discarded; only the constructor's side effect is wanted.
    ProjectArtifactRegisterer()