from functools import wraps
from typing import Callable, List, Optional
class ModuleRegistry:
    """Lookup table from a module name to the class registered for it.

    Every name can be registered only once; asking for an unknown name is
    an error rather than returning a default.
    """

    def __init__(self):
        # name -> class object passed to ``register_module``
        self.module_dict = {}

    def register_module(self, cls_name: str, cls):
        """Store ``cls`` under ``cls_name``.

        Raises:
            ValueError: if ``cls_name`` has already been registered.
        """
        table = self.module_dict
        if cls_name not in table:
            table[cls_name] = cls
            return
        raise ValueError(f"Found duplicate module: `{cls_name}`!")

    def get_module(self, cls_name: str):
        """Return the class registered under ``cls_name``.

        Raises:
            ValueError: if ``cls_name`` has never been registered.
        """
        # Private sentinel so a class stored as ``None`` is still "found".
        missing = object()
        found = self.module_dict.get(cls_name, missing)
        if found is missing:
            raise ValueError(f"module `{cls_name}` not Found!")
        return found

    def has_module(self, cls_name: str) -> bool:
        """Report whether ``cls_name`` is currently registered."""
        registered = cls_name in self.module_dict
        return registered
| |
|
MODULE_REGISTRY = ModuleRegistry()


def register_module(cls_name, cls):
    """Add ``cls`` to the shared ``MODULE_REGISTRY`` under ``cls_name``.

    Module-level convenience wrapper around
    ``ModuleRegistry.register_module``; a duplicate name raises that
    method's ``ValueError``.
    """
    registry = MODULE_REGISTRY
    registry.register_module(cls_name=cls_name, cls=cls)
| |
|
| |
|
class ModelRegistry:
    """Registry mapping model names to a model class and its config class.

    Each key is stored in both ``models`` and ``model_configs`` so the model
    class and its paired config class can be looked up independently.
    """

    def __init__(self):
        # key -> model class
        self.models = {}
        # key -> config class registered together with that model
        self.model_configs = {}

    def register(self, key: str, model_cls, config_cls):
        """Register ``model_cls`` and its ``config_cls`` under ``key``.

        Raises:
            ValueError: if ``key`` is already registered.
        """
        if key in self.models:
            raise ValueError(f"model name '{key}' is already registered!")
        self.models[key] = model_cls
        self.model_configs[key] = config_cls

    def key_error_message(self, key: str):
        """Build the error message used when ``key`` is not a registered name."""
        return (
            f"`{key}` is not a registered model name. "
            f"Currently available model names: {self.get_model_names()}. "
            f"If `{key}` is a customized model, decorate the model class with "
            f"@register_model(<config_cls>) to register it under its class name "
            f"(pass alias=[...] for additional names)."
        )

    def get_model(self, key: str):
        """Return the model class registered under ``key``.

        Raises:
            KeyError: if ``key`` has not been registered.
        """
        # Membership test rather than ``.get(...) is None`` so a value that
        # happens to be ``None`` is not misreported as "not registered".
        if key not in self.models:
            raise KeyError(self.key_error_message(key))
        return self.models[key]

    def get_model_config(self, key: str):
        """Return the config class registered under ``key``.

        Raises:
            KeyError: if ``key`` has not been registered.
        """
        # ``config_cls`` may legitimately be ``None``; only a missing key is
        # an error.
        if key not in self.model_configs:
            raise KeyError(self.key_error_message(key))
        return self.model_configs[key]

    def get_model_names(self):
        """Return all registered keys (class names and aliases) in insertion order."""
        return list(self.models)
| |
|
| |
|
MODEL_REGISTRY = ModelRegistry()


def register_model(config_cls, alias: Optional[List[str]] = None):
    """Class decorator that registers a model class in ``MODEL_REGISTRY``.

    The decorated class is registered under its ``__name__`` and under every
    name in ``alias``. Each name is paired with ``config_cls``.

    Args:
        config_cls: Config class stored alongside the model class.
        alias: Optional extra names for the model. A single string is
            treated as one alias rather than as a sequence of characters.

    Returns:
        A decorator that registers the class and returns it unchanged.

    Raises:
        ValueError: when the decorator is applied, if any of the names is
            already registered or appears more than once. In that case
            nothing is registered.
    """

    def decorator(cls):
        if isinstance(alias, str):
            # A bare string would otherwise be iterated character by character.
            extra_names = [alias]
        else:
            extra_names = list(alias) if alias is not None else []
        names = [cls.__name__, *extra_names]

        # Validate every name before registering any, so a conflicting alias
        # cannot leave the class half-registered.
        seen = set()
        for name in names:
            if name in MODEL_REGISTRY.models or name in seen:
                raise ValueError(f"model name '{name}' is already registered!")
            seen.add(name)

        for name in names:
            MODEL_REGISTRY.register(name, cls, config_cls)
        return cls

    return decorator
| |
|
class ParseFunctionRegistry:
    """Name -> callable registry for parse functions.

    Names are unique: registering the same name twice raises ``ValueError``.
    """

    def __init__(self):
        # name -> callable registered under that name
        self.functions = {}

    def register(self, func_name: str, func: Callable):
        """Register a function with a given name.

        Args:
            func_name: The name to register the function under
            func: The function to register

        Raises:
            ValueError: If a function with the same name is already registered
        """
        if func_name in self.functions:
            raise ValueError(f"Function name '{func_name}' is already registered!")
        self.functions[func_name] = func

    def get_function(self, func_name: str) -> Callable:
        """Get a registered function by name.

        Args:
            func_name: The name of the function to retrieve

        Returns:
            The registered function

        Raises:
            KeyError: If no function with the given name is registered; the
                message lists the currently registered names.
        """
        if func_name not in self.functions:
            available_funcs = list(self.functions)
            raise KeyError(f"Function '{func_name}' not found! Available functions: {available_funcs}")
        return self.functions[func_name]

    def has_function(self, func_name: str) -> bool:
        """Check if a function name is registered.

        Args:
            func_name: The name to check

        Returns:
            True if the function name is registered, False otherwise
        """
        return func_name in self.functions
| |
|
| |
|
PARSE_FUNCTION_REGISTRY = ParseFunctionRegistry()


def register_parse_function(func):
    """Decorator: register ``func`` in ``PARSE_FUNCTION_REGISTRY`` by ``__name__``.

    The registry stores, and this decorator returns, a thin forwarding
    wrapper produced by ``functools.wraps``. The wrapper carries ``func``'s
    name and docstring and exposes the original as ``__wrapped__``.

    Raises:
        ValueError: if a function with the same ``__name__`` is already
            registered.
    """
    def forward(*call_args, **call_kwargs):
        return func(*call_args, **call_kwargs)

    forward = wraps(func)(forward)
    PARSE_FUNCTION_REGISTRY.register(func.__name__, forward)
    return forward
| |
|
| |
|
class ActionFunctionRegistry:
    """Name -> callable registry for action functions.

    Names are unique: registering the same name twice raises ``ValueError``.
    """

    def __init__(self):
        # name -> callable registered under that name
        self.functions = {}

    def register(self, func_name: str, func: Callable):
        """Register a function with a given name.

        Args:
            func_name: The name to register the function under
            func: The function to register

        Raises:
            ValueError: If a function with the same name is already registered
        """
        if func_name in self.functions:
            raise ValueError(f"Function name '{func_name}' is already registered!")
        self.functions[func_name] = func

    def get_function(self, func_name: str) -> Callable:
        """Get a registered function by name.

        Args:
            func_name: The name of the function to retrieve

        Returns:
            The registered function

        Raises:
            KeyError: If no function with the given name is registered; the
                message lists the currently registered names.
        """
        if func_name not in self.functions:
            available_funcs = list(self.functions)
            raise KeyError(f"Function '{func_name}' not found! Available functions: {available_funcs}")
        return self.functions[func_name]

    def has_function(self, func_name: str) -> bool:
        """Check if a function name is registered.

        Args:
            func_name: The name to check

        Returns:
            True if the function name is registered, False otherwise
        """
        return func_name in self.functions
| |
|
| |
|
ACTION_FUNCTION_REGISTRY = ActionFunctionRegistry()


def register_action_function(func):
    """Register a function for ActionAgent serialization.

    The function is stored in ``ACTION_FUNCTION_REGISTRY`` under
    ``func.__name__``.

    Args:
        func (Callable): The function to register

    Returns:
        Callable: A ``functools.wraps`` wrapper that forwards every call to
        ``func``; this is NOT ``func`` itself. The registry stores the same
        wrapper object, so the decorated name and the registry entry are
        identical. The original function is reachable via the wrapper's
        ``__wrapped__`` attribute.

    Raises:
        ValueError: If a function with the same ``__name__`` is already
            registered.
    """
    @wraps(func)
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)

    ACTION_FUNCTION_REGISTRY.register(func.__name__, wrapper)
    return wrapper
| | |