import uuid
from collections import OrderedDict
from functools import wraps
from typing import Callable, Dict, List, Optional, Type

import torch.nn as nn
from torch.distributed._composable_state import _State


def generate_state_key(string: str = "__composable_api_state_key") -> str:
    """Return ``string`` suffixed with ``_`` and a freshly generated UUID4."""
    return f"{string}_{uuid.uuid4()}"


# Keys under which per-module composable-API state and the API registry are
# stored in ``module.__dict__``. Both are randomized once at import time.
STATE_KEY = generate_state_key()
REGISTRY_KEY = generate_state_key()


# TODO: we can add additional info to RegistryItem to share across APIs. E.g.,
# we can add args and kwargs here, and then we can detect whether fully_shard
# is combined with reentrant activation checkpointing and error out with a clear
# message.
class RegistryItem:
    pass


def _check_fqn(orig_fqns: List[str], new_fqns: List[str], check_key: str) -> None:
    """Raise ``RuntimeError`` if ``new_fqns`` is not identical to ``orig_fqns``.

    Args:
        orig_fqns: FQNs captured before the composable API ran.
        new_fqns: FQNs captured after the composable API ran.
        check_key: Prefix for the error message naming what was checked.

    Raises:
        RuntimeError: if any FQN was added or removed, or if the FQN sets are
            equal but the order differs.
    """
    if orig_fqns == new_fqns:
        return
    orig_fqn_set, new_fqn_set = set(orig_fqns), set(new_fqns)
    orig_only = orig_fqn_set - new_fqn_set
    new_only = new_fqn_set - orig_fqn_set
    if orig_only or new_only:
        raise RuntimeError(
            f"{check_key}"
            "Composable distributed API implementations cannot modify "
            "FQNs.\n"
            f"Only in original FQNs: {orig_only},\n"
            f"Only in new FQNs: {new_only}"
        )
    # Same members, different order. The set differences are empty here, so
    # report the full ordered lists to make the change visible.
    raise RuntimeError(
        f"{check_key}"
        "Composable distributed API implementations cannot modify "
        "the order of FQNs.\n"
        f"Original FQNs: {orig_fqns}\n"
        f"New FQNs: {new_fqns}"
    )


def contract(state_cls: Type[_State] = _State):
    r"""
    Decorate a function as a composable distributed API, where the first
    argument of the function must be an :class:`nn.Module` instance. The
    decorator verifies that the wrapped function does not modify parameter,
    buffer or sub-module fully-qualified names (FQN).

    When a function ``func`` is decorated by ``@contract()``, a
    ``.state(module: nn.Module)`` method will be installed to the decorated
    function. Then you can retrieve and modify the state on a module by calling
    ``func.state(module)``.

    Args:
        state_cls: Class instantiated with no arguments to hold the per-module
            state of the decorated API. Defaults to :class:`_State`.

    The decorated function raises:
        AssertionError: if this API (or another API with the same
            ``__name__``) was already applied to ``module``, if the stored
            state/registry on ``module`` is not a ``dict``, or if the wrapped
            function returns something other than ``None`` or an
            :class:`nn.Module`. These checks use ``assert`` and are therefore
            skipped under ``python -O``.
        RuntimeError: if the wrapped function adds, removes, or reorders
            parameter, buffer, or sub-module FQNs.

    Example::
        >>> # xdoctest: +SKIP
        >>> import torch
        >>> import torch.nn as nn
        >>>
        >>> class MyModel(nn.Module):
        >>>     def __init__(self):
        >>>         super().__init__()
        >>>         self.l1 = nn.Linear(10, 10)
        >>>         self.l2 = nn.Linear(10, 10)
        >>>
        >>>     def forward(self, x):
        >>>         return self.l2(self.l1(x))
        >>>
        >>> @contract()
        >>> def my_feature(module: nn.Module) -> nn.Module:
        >>>     my_feature.state(module).some_state = "any value"
        >>>     return module
        >>>
        >>> model = MyModel()
        >>> my_feature(model.l1)
        >>> assert my_feature.state(model.l1).some_state == "any value"
        >>> my_feature(model.l2)
        >>> model(torch.randn(2, 10)).sum().backward()
    """

    # wraps will make functions decorated with contract() pickleable - needed for integration with torch.package
    @wraps(state_cls)
    def inner(func):
        @wraps(func)
        def wrapper(module: nn.Module, *args, **kwargs) -> Optional[nn.Module]:
            # get existing global states
            default_all_state: Dict[Callable, _State] = OrderedDict()
            all_state: Dict[Callable, _State] = module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY, default_all_state
            )
            assert isinstance(
                all_state, dict
            ), "Distributed composable API states corrupted"

            # get global registry
            default_registry: Dict[str, RegistryItem] = OrderedDict()
            registry: Dict[str, RegistryItem] = module.__dict__.setdefault(  # type: ignore[call-overload]
                REGISTRY_KEY, default_registry
            )

            assert isinstance(
                registry, dict
            ), "Distributed composable API registry corrupted"

            # make sure the API func has not been applied to the input module yet.
            assert func not in all_state and func.__name__ not in registry, (
                "Each distinct composable distributed API can only be applied to a "
                f"module once. {func.__name__} has already been applied to the "
                f"following module.\n{module}"
            )

            # NOTE: state and registry entries are installed before ``func``
            # runs, so ``func`` can read ``func.state(module)``. If ``func``
            # raises, the module nevertheless stays marked as applied.
            # install states specific to the wrapped ``func``
            all_state.setdefault(func, state_cls())
            # register ``func`` in the global registry by name
            registry.setdefault(func.__name__, RegistryItem())

            # Snapshot FQNs (keeping duplicates for buffers/modules) so changes
            # made by ``func`` can be detected afterwards.
            orig_named_params = OrderedDict(module.named_parameters())
            orig_named_buffers = OrderedDict(
                module.named_buffers(remove_duplicate=False)
            )
            orig_named_modules = OrderedDict(
                module.named_modules(remove_duplicate=False)
            )

            updated = func(module, *args, **kwargs)

            if updated is None:
                updated = module

            # Validate the return type before calling nn.Module methods on it,
            # so a bad return surfaces as this message, not an AttributeError.
            assert isinstance(updated, nn.Module), (
                "Output of composable distributed APIs must be either None or "
                f"nn.Module, but got {type(updated)}"
            )

            new_named_params = OrderedDict(updated.named_parameters())
            new_named_buffers = OrderedDict(
                updated.named_buffers(remove_duplicate=False)
            )
            new_named_modules = OrderedDict(
                updated.named_modules(remove_duplicate=False)
            )

            _check_fqn(
                list(orig_named_params.keys()),
                list(new_named_params.keys()),
                "Check parameters, ",
            )
            _check_fqn(
                list(orig_named_buffers.keys()),
                list(new_named_buffers.keys()),
                "Check buffer, ",
            )
            _check_fqn(
                list(orig_named_modules.keys()),
                list(new_named_modules.keys()),
                "Check modules, ",
            )

            # TODO: a stricter verification should also reject changing module
            # types and monkey-patching forward() method implementations.

            # TODO: verify that installed distributed paradigms are compatible with
            # each other.

            return updated

        def get_state(module: nn.Module) -> Optional[_State]:
            """Return the state ``func`` installed on ``module``, or ``None``.

            NOTE: if ``module`` has no state dict yet, an empty ``{}`` is
            inserted into ``module.__dict__`` under ``STATE_KEY`` as a side
            effect of this lookup.
            """
            return module.__dict__.setdefault(  # type: ignore[call-overload]
                STATE_KEY,
                {},  # TODO(@yhcharles): this is a temporary fix, need a better way
            ).get(
                func
            )  # type: ignore[call-overload]

        wrapper.state = get_state  # type: ignore[attr-defined]

        return wrapper

    return inner


def _get_registry(module: nn.Module) -> Optional[Dict[str, RegistryItem]]:
    r"""
    Get an ``OrderedDict`` of composable APIs that have been applied to the
    ``module``, indexed by the API name. If no API has been applied, then this
    returns ``None``.
    """
    return getattr(module, REGISTRY_KEY, None)