# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn

from mmengine.registry import MODEL_WRAPPERS, Registry


def is_model_wrapper(model: nn.Module,
                     registry: Registry = MODEL_WRAPPERS) -> bool:
    """Check if a module is a model wrapper.

    A model is regarded as a model wrapper if it is an instance of any
    class (or a subclass of any class) registered in ``registry`` or in
    any of its descendant registries. Wrappers such as ``DataParallel``
    and ``DistributedDataParallel`` are recognised once they are
    registered there. You may add your own model wrapper by registering
    it to ``mmengine.registry.MODEL_WRAPPERS``.

    Args:
        model (nn.Module): The model to be checked.
        registry (Registry): The parent registry to search for model
            wrappers. Its child registries are searched recursively.
            Defaults to ``MODEL_WRAPPERS``.

    Returns:
        bool: True if the input model is a model wrapper.
    """
    # ``isinstance`` with a tuple of classes also matches subclasses of
    # every registered wrapper; an empty tuple simply yields False.
    module_wrappers = tuple(registry.module_dict.values())
    if isinstance(model, module_wrappers):
        return True

    # Nothing further to search when this registry has no children.
    if not registry.children:
        return False

    # Wrappers may be registered in downstream (child) registries rather
    # than in this one, so descend into each child recursively.
    return any(
        is_model_wrapper(model, child) for child in registry.children.values())