# Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

from argparse import Namespace
from typing import Dict, Any

import torch

from .adaptor_generic import GenericAdaptor, AdaptorBase

dict_t = Dict[str, Any]
state_t = Dict[str, torch.Tensor]


class AdaptorRegistry:
    """Registry mapping a name to the factory callable that builds its adaptor."""

    def __init__(self):
        # name -> factory callable, filled in through `register_adaptor`.
        self._registry = {}

    def register_adaptor(self, name):
        """Return a decorator that records the decorated factory under ``name``.

        The decorator hands the factory back unchanged, so the decorated
        function stays usable under its own name.

        Raises:
            ValueError: when the decorator is applied and ``name`` already has
                a factory registered.
        """

        def _register(factory):
            # The duplicate check runs when the decorator is applied, not when
            # `register_adaptor` itself is called.
            if name in self._registry:
                raise ValueError(f"Model '{name}' already registered")
            self._registry[name] = factory
            return factory

        return _register

    def create_adaptor(self, name, main_config: Namespace, adaptor_config: dict_t, state: state_t) -> AdaptorBase:
        """Build the adaptor registered under ``name``.

        The factory registered for ``name`` is called with
        ``(main_config, adaptor_config, state)``. A name with no registered
        factory falls back to ``GenericAdaptor``, called with the same
        arguments.
        """
        # One lookup: unregistered names resolve to the generic adaptor.
        factory = self._registry.get(name, GenericAdaptor)
        return factory(main_config, adaptor_config, state)


# Creating an instance of the registry
adaptor_registry = AdaptorRegistry()