| | |
| |
|
| | import copy |
| | import functools |
| | import inspect |
| | import weakref |
| | from collections import defaultdict |
| | from collections.abc import Iterable |
| | from contextlib import contextmanager |
| | from typing import Callable, Union |
| |
|
| | from .base import BaseTransform |
| |
|
| |
|
class cache_randomness:
    """Decorator that marks the method with random return value(s) in a
    transform class.

    This decorator is usually used together with the context-manager
    :func:`cache_random_params`. In this context, a decorated method will
    cache its return value(s) the first time it is invoked, and always
    return the cached values when it is invoked again.

    The name of every decorated method is recorded in the owner class's
    ``_methods_with_randomness`` list. A class's list also contains the
    names recorded on its base classes, and each name appears only once.

    .. note::
        Only an instance method can be decorated with ``cache_randomness``.
    """

    def __init__(self, func):
        # Only plain functions whose first parameter is ``self`` can be
        # bound as instance methods by ``__get__``.
        if not inspect.isfunction(func):
            raise TypeError('Unsupported callable to decorate with '
                            '@cache_randomness.')
        func_args = inspect.getfullargspec(func).args
        if len(func_args) == 0 or func_args[0] != 'self':
            raise TypeError(
                '@cache_randomness should only be used to decorate '
                'instance methods (the first argument is ``self``).')

        # Copy ``__name__``/``__doc__``/... from ``func``; ``__call__`` uses
        # ``self.__name__`` as the cache key.
        functools.update_wrapper(self, func)
        self.func = func
        # Set only on the per-instance copies returned by ``__get__``; it
        # stays ``None`` on the descriptor stored in the class.
        self.instance_ref = None

    def __set_name__(self, owner, name):
        # Give every class its own list, seeded with the names recorded on
        # its bases, so that decorating a method in a subclass does not
        # append to the list object inherited from a parent class.
        if '_methods_with_randomness' not in owner.__dict__:
            inherited = []
            for base in owner.__mro__[1:]:
                for base_name in base.__dict__.get('_methods_with_randomness',
                                                   ()):
                    if base_name not in inherited:
                        inherited.append(base_name)
            owner._methods_with_randomness = inherited

        # ``name`` equals ``self.__name__`` thanks to ``update_wrapper`` in
        # ``__init__``. Skip duplicates (e.g. a subclass overriding a
        # decorated method), which would otherwise be wrapped twice by
        # ``cache_random_params``.
        if name not in owner._methods_with_randomness:
            owner._methods_with_randomness.append(name)

    def __call__(self, *args, **kwargs):
        if self.instance_ref is None:
            # Unbound use through the class, e.g. ``Cls.method(obj, ...)``:
            # as with a plain function, the instance is the first argument.
            if not args:
                raise TypeError(
                    f'{self.__name__}() missing 1 required positional '
                    "argument: 'self'")
            instance, args = args[0], args[1:]
        else:
            # The transform instance this copy was bound to by ``__get__``.
            instance = self.instance_ref()
        name = self.__name__

        # The flag is set by context managers like ``cache_random_params``.
        if getattr(instance, '_cache_enabled', False):
            # Lazily create the per-instance cache. The entry is computed on
            # the first call, and later calls return the cached value.
            if not hasattr(instance, '_cache'):
                instance._cache = {}
            if name not in instance._cache:
                instance._cache[name] = self.func(instance, *args, **kwargs)
            return instance._cache[name]

        # Caching is off: drop values left from a previous cached run and
        # compute a fresh result.
        if hasattr(instance, '_cache'):
            del instance._cache
        return self.func(instance, *args, **kwargs)

    def __get__(self, obj, cls):
        if obj is None:
            # Accessed on the class itself (``Cls.method``): nothing to bind.
            return self
        # Bind a shallow copy instead of ``self``. The descriptor is shared
        # by all instances of the owner class, so storing the instance on it
        # would let concurrent lookups (e.g. from several threads) overwrite
        # each other's binding.
        bound = copy.copy(self)
        bound.instance_ref = weakref.ref(obj)
        return bound
| |
|
| |
|
def avoid_cache_randomness(cls):
    """Decorator that marks a data transform class (subclass of
    :class:`BaseTransform`) as prohibited from caching randomness. With this
    decorator, errors will be raised in the following cases:

        1. A method is defined in the class with the decorator
           ``cache_randomness``;
        2. An instance of the class is invoked within the context
           ``cache_random_params``.

    A typical usage of ``avoid_cache_randomness`` is to decorate data
    transforms with non-cacheable random behaviors (e.g., the random behavior
    cannot be defined in a method, and thus cannot be decorated with
    ``cache_randomness``). This prevents unintentional use of such data
    transforms within the context of caching randomness, which may lead to
    unexpected results.

    The mark is not inherited. Subclasses of a decorated class are not
    prohibited unless they are decorated themselves.

    Args:
        cls (type): The transform class to decorate.

    Returns:
        type: ``cls`` itself, with the mark attached.

    Raises:
        AssertionError: If ``cls`` is not a subclass of ``BaseTransform``.
        RuntimeError: If ``cls`` has methods decorated with
            ``cache_randomness``.
    """
    # Raise explicitly instead of using ``assert`` so the check survives
    # ``python -O``. The exception type stays ``AssertionError`` for existing
    # callers, and non-class inputs no longer surface as TypeError.
    if not (inspect.isclass(cls) and issubclass(cls, BaseTransform)):
        raise AssertionError(
            '``avoid_cache_randomness`` can only decorate subclasses of '
            f'BaseTransform, but got {cls!r}.')

    # The class must not have cacheable random methods. This includes names
    # inherited from base classes, since ``getattr`` follows inheritance.
    if getattr(cls, '_methods_with_randomness', None):
        raise RuntimeError(
            f'Class {cls.__name__} decorated with '
            '``avoid_cache_randomness`` should not have methods decorated '
            'with ``cache_randomness`` (invalid methods: '
            f'{cls._methods_with_randomness})')

    class AvoidCacheRandomness:
        """Descriptor reporting whether the accessed class is marked.

        It reads ``_avoid_cache_randomness`` from the class's own
        ``__dict__`` only, so subclasses of a marked class report ``False``.
        """

        def __get__(self, obj, objtype=None):
            if objtype is None:
                objtype = type(obj)
            return objtype.__dict__.get('_avoid_cache_randomness', False)

    cls.avoid_cache_randomness = AvoidCacheRandomness()
    cls._avoid_cache_randomness = True

    return cls
| |
|
| |
|
@contextmanager
def cache_random_params(transforms: Union[BaseTransform, Iterable]):
    """Context-manager that enables the cache of return values of methods
    decorated with ``cache_randomness`` in transforms.

    In this mode, decorated methods cache their return values on the first
    invocation and always return the cached value afterwards. This allows
    random transforms to be applied in a deterministic way, e.g. applying
    the same transforms to multiple examples. See ``cache_randomness`` for
    more information.

    While the context is active, each transform's ``transform`` method is
    wrapped to check that every decorated method is invoked at most once per
    call. On exit, the instance attributes changed by this context are
    restored to their previous state.

    Transforms that are already caching are left untouched. This covers a
    transform listed more than once, or one managed by an enclosing
    ``cache_random_params``.

    Args:
        transforms (BaseTransform|list[BaseTransform]): The transforms to
            enable cache. Nested iterables are traversed recursively.

    Raises:
        RuntimeError: If a transform is decorated with
            ``avoid_cache_randomness``, or if a decorated method is invoked
            more than once within a single ``transform`` call.
    """
    # Marks "the attribute was absent from the instance ``__dict__``".
    missing = object()

    # Number of invocations of each decorated method during the current
    # ``transform`` call, keyed by ``f'{id(obj)}.{method_name}'``.
    key2counter: dict = defaultdict(int)

    # Transforms set up by this context, in setup order. Each entry pairs a
    # transform with a snapshot of the instance attributes it overwrote.
    started: list = []

    def _add_invoke_counter(obj, method_name):
        """Wrap ``obj.<method_name>`` so that each call is counted."""
        method = getattr(obj, method_name)
        key = f'{id(obj)}.{method_name}'

        @functools.wraps(method)
        def wrapped(*args, **kwargs):
            key2counter[key] += 1
            return method(*args, **kwargs)

        return wrapped

    def _add_invoke_checker(obj, method_name):
        """Wrap ``obj.<method_name>`` (the ``transform`` method) to check
        that each decorated method is invoked at most once per call."""
        method = getattr(obj, method_name)

        @functools.wraps(method)
        def wrapped(*args, **kwargs):
            # Reset the counters before processing this data sample.
            for name in obj._methods_with_randomness:
                key2counter[f'{id(obj)}.{name}'] = 0

            output = method(*args, **kwargs)

            for name in obj._methods_with_randomness:
                key = f'{id(obj)}.{name}'
                if key2counter[key] > 1:
                    raise RuntimeError(
                        'The method decorated with ``cache_randomness`` '
                        'should be invoked at most once during processing '
                        f'one data sample. The method {name} of {obj} has '
                        f'been invoked {key2counter[key]} times.')
            return output

        return wrapped

    def _start_cache(t: BaseTransform):
        # Transforms with non-cacheable randomness must not be used here.
        if getattr(t, 'avoid_cache_randomness', False):
            raise RuntimeError(
                f'Class {t.__class__.__name__} decorated with '
                '``avoid_cache_randomness`` is not allowed to be used with'
                ' ``cache_random_params`` (e.g. wrapped by '
                '``ApplyToMultiple`` with ``share_random_params==True``).')

        # Skip transforms without random methods.
        if not hasattr(t, '_methods_with_randomness'):
            return

        # Skip transforms that are already caching. Wrapping them a second
        # time would count every invocation twice and break the restore.
        if getattr(t, '_cache_enabled', False):
            return

        # Record the transform before changing it, so a failure part-way
        # through the setup below is still undone on exit.
        inst_dict = vars(t)
        saved = {
            name: inst_dict.get(name, missing)
            for name in ['transform', *t._methods_with_randomness]
        }
        started.append((t, saved))

        t._cache_enabled = True
        t.transform = _add_invoke_checker(t, 'transform')
        for name in t._methods_with_randomness:
            setattr(t, name, _add_invoke_counter(t, name))

    def _end_cache(t: BaseTransform, saved: dict):
        # Turn caching off and drop the cached values.
        inst_dict = vars(t)
        inst_dict.pop('_cache_enabled', None)
        inst_dict.pop('_cache', None)
        # Put back instance attributes that existed before the context.
        # Otherwise remove the wrappers so the class attributes are found
        # again; stale bound methods are not left on the instance.
        for name, value in saved.items():
            if value is missing:
                inst_dict.pop(name, None)
            else:
                setattr(t, name, value)

    def _apply(t: Union[BaseTransform, Iterable],
               func: Callable[[BaseTransform], None]):
        if isinstance(t, BaseTransform):
            func(t)
        # Strings are iterable (each character is a string again) but never
        # hold transforms; recursing into them would never terminate.
        if isinstance(t, Iterable) and not isinstance(t, (str, bytes)):
            for _t in t:
                _apply(_t, func)

    try:
        _apply(transforms, _start_cache)
        yield
    finally:
        # Undo in reverse order, and only for transforms actually set up.
        # A failure while entering is then not masked by errors from
        # transforms that were never set up, and one-shot iterables are not
        # traversed a second time.
        for t, saved in reversed(started):
            _end_cache(t, saved)
| |
|