| try: |
| import namex |
| except ImportError: |
| namex = None |
|
|
|
|
| |
| |
| REGISTERED_NAMES_TO_OBJS = {} |
| REGISTERED_OBJS_TO_NAMES = {} |
|
|
|
|
| def register_internal_serializable(path, symbol): |
| global REGISTERED_NAMES_TO_OBJS |
| if isinstance(path, (list, tuple)): |
| name = path[0] |
| else: |
| name = path |
| REGISTERED_NAMES_TO_OBJS[name] = symbol |
| REGISTERED_OBJS_TO_NAMES[symbol] = name |
|
|
|
|
| def get_symbol_from_name(name): |
| return REGISTERED_NAMES_TO_OBJS.get(name, None) |
|
|
|
|
| def get_name_from_symbol(symbol): |
| return REGISTERED_OBJS_TO_NAMES.get(symbol, None) |
|
|
|
|
| if namex: |
|
|
| class keras_export(namex.export): |
| def __init__(self, path): |
| super().__init__(package="keras", path=path) |
|
|
| def __call__(self, symbol): |
| register_internal_serializable(self.path, symbol) |
| return super().__call__(symbol) |
|
|
| else: |
|
|
| class keras_export: |
| def __init__(self, path): |
| self.path = path |
|
|
| def __call__(self, symbol): |
| register_internal_serializable(self.path, symbol) |
| return symbol |
|
|