rawalkhirodkar's picture
Add initial commit
28c256d
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import inspect
import threading
import warnings
from collections import OrderedDict
from typing import Type, TypeVar
# Type variable so that ``get_instance`` is typed as returning an instance
# of whichever class it is called on.
T = TypeVar("T")
# Module-wide re-entrant lock taken by ``_accquire_lock`` / released by
# ``_release_lock``; an RLock, so the owning thread may acquire it again.
_lock = threading.RLock()
def _accquire_lock() -> None:
    """Take the module-wide re-entrant lock ``_lock``.

    Every call must be balanced by a matching ``_release_lock()`` call.
    Nothing happens if ``_lock`` is falsy.
    """
    if not _lock:
        return
    _lock.acquire()
def _release_lock() -> None:
    """Give back the module-wide lock taken by ``_accquire_lock()``.

    Nothing happens if ``_lock`` is falsy.
    """
    if not _lock:
        return
    _lock.release()
class ManagerMeta(type):
    """Metaclass that gives every class it builds its own instance registry.

    Each class created with this metaclass, including every subclass, gets
    a fresh, empty ``_instance_dict`` (an ``OrderedDict``). Registries are
    therefore not shared along the inheritance chain. The class must also
    have a constructor that accepts a ``name`` argument. Otherwise class
    creation fails with ``AssertionError``.

    Examples:
        >>> class SubClass1(metaclass=ManagerMeta):
        >>>     def __init__(self, *args, **kwargs):
        >>>         pass
        AssertionError: <class '__main__.SubClass1'> must have the `name`
        argument
        >>> class SubClass2(metaclass=ManagerMeta):
        >>>     def __init__(self, name):
        >>>         pass
        >>> # valid format.
    """

    def __init__(cls, *args):
        # Attach a brand-new registry before validating the signature.
        cls._instance_dict = OrderedDict()
        # For a class, getfullargspec reports its constructor's positional
        # parameter names (``self`` included) in ``.args``.
        argument_names = inspect.getfullargspec(cls).args or []
        assert 'name' in argument_names, \
            f'{cls} must have the `name` argument'
        super().__init__(*args)
class ManagerMixin(metaclass=ManagerMeta):
    """``ManagerMixin`` is the base class for classes that have global access
    requirements.

    The subclasses inheriting from ``ManagerMixin`` can get their global
    instances by name. Each class has its own registry, the
    ``_instance_dict`` attached by ``ManagerMeta``. Only instances created
    through ``get_instance`` are registered. Constructing an instance
    directly does not register it.

    Examples:
        >>> class GlobalAccessible(ManagerMixin):
        >>>     def __init__(self, name=''):
        >>>         super().__init__(name)
        >>>
        >>> GlobalAccessible.get_instance('name')
        >>> instance_1 = GlobalAccessible.get_instance('name')
        >>> instance_2 = GlobalAccessible.get_instance('name')
        >>> assert id(instance_1) == id(instance_2)

    Args:
        name (str): Name of the instance. Must be a non-empty string. The
            default ``''`` fails that check, so callers must pass a name.
            Defaults to ''.
    """

    def __init__(self, name: str = '', **kwargs):
        # Extra keyword arguments are accepted but not used here.
        assert isinstance(name, str) and name, \
            'name argument must be an non-empty string.'
        self._instance_name = name

    @classmethod
    def get_instance(cls: Type[T], name: str, **kwargs) -> T:
        """Get the subclass instance registered under ``name``.

        If no instance named ``name`` exists yet for this class,
        ``get_instance`` creates one with ``cls(name=name, **kwargs)`` and
        registers it. Otherwise it returns the existing instance. If extra
        ``kwargs`` are passed for an existing instance, they are ignored
        and a warning is emitted.

        Examples
            >>> instance1 = GlobalAccessible.get_instance('name1')
            >>> # Create name1 instance.
            >>> instance1.instance_name
            name1
            >>> instance2 = GlobalAccessible.get_instance('name1')
            >>> # Get name1 instance.
            >>> assert id(instance1) == id(instance2)

        Args:
            name (str): Name of instance.
            **kwargs: Extra constructor arguments. They are used only when
                this call creates the instance.

        Returns:
            object: The instance registered under ``name``.

        Raises:
            AssertionError: If ``name`` is not a ``str`` (when assertions
                are enabled).
            Exception: Anything raised by the constructor propagates, and
                nothing is registered in that case.
        """
        _accquire_lock()
        # Release in ``finally``. If the check fails or the constructor
        # raises, the lock would otherwise stay held forever and block
        # every other thread that calls into the registry.
        try:
            assert isinstance(name, str), \
                f'type of name should be str, but got {type(name)}'
            instance_dict = cls._instance_dict  # type: ignore
            if name not in instance_dict:
                instance = cls(name=name, **kwargs)  # type: ignore
                instance_dict[name] = instance  # type: ignore
            elif kwargs:
                warnings.warn(
                    f'{cls} instance named of {name} has been created, '
                    'the method `get_instance` should not accept any other '
                    'arguments')
            # Read while still holding the lock so the result matches the
            # membership check above.
            return instance_dict[name]
        finally:
            _release_lock()

    @classmethod
    def get_current_instance(cls):
        """Get the most recently created (registered) instance.

        Before calling ``get_current_instance``, the subclass must have called
        ``get_instance(xxx)`` at least once.

        Examples
            >>> instance = GlobalAccessible.get_current_instance()
            RuntimeError: Before calling GlobalAccessible.get_current_instance(), you should call get_instance(name=xxx) at least once.
            >>> instance = GlobalAccessible.get_instance('name1')
            >>> instance.instance_name
            name1
            >>> instance = GlobalAccessible.get_current_instance()
            >>> instance.instance_name
            name1

        Returns:
            object: Latest created instance.

        Raises:
            RuntimeError: If no instance has been registered for this class.
        """
        _accquire_lock()
        # ``finally`` guarantees the lock is released even when the
        # RuntimeError below is raised.
        try:
            if not cls._instance_dict:
                raise RuntimeError(
                    f'Before calling {cls.__name__}.get_current_instance(), '
                    'you should call get_instance(name=xxx) at least once.')
            # The registry is an OrderedDict that is only appended to, so
            # its last key is the most recently registered name.
            name = next(reversed(cls._instance_dict))
            return cls._instance_dict[name]
        finally:
            _release_lock()

    @classmethod
    def check_instance_created(cls, name: str) -> bool:
        """Check whether an instance named ``name`` has been registered.

        Args:
            name (str): Name of instance.

        Returns:
            bool: Whether an instance registered under ``name`` exists.
        """
        return name in cls._instance_dict

    @property
    def instance_name(self) -> str:
        """Get the name of instance.

        Returns:
            str: Name of instance.
        """
        return self._instance_name