import yaml | |
import warnings | |
from pdf_extract_kit.registry.registry import TASK_REGISTRY | |
def load_config(config_path):
    """Load a YAML configuration file and return its parsed contents.

    Args:
        config_path: Path of the YAML file to read, or ``None``.

    Returns:
        Whatever ``yaml.safe_load`` produces for the file (``None`` for an
        empty document), or ``None`` when ``config_path`` is ``None``.

    Warns:
        UserWarning: When ``config_path`` is ``None``; nothing is read in
            that case.
    """
    if config_path is None:
        # stacklevel=2 attributes the warning to the caller that passed None
        # rather than to this helper.
        warnings.warn(
            "Configuration path is None. Please provide a valid configuration file path.",
            stacklevel=2,
        )
        return None

    # Decode as UTF-8 explicitly: the default for open() is the platform
    # locale encoding, which breaks non-ASCII config values on systems whose
    # locale is not UTF-8.
    with open(config_path, 'r', encoding='utf-8') as file:
        return yaml.safe_load(file)
def initialize_tasks_and_models(config):
    """Instantiate every task described under ``config['tasks']``.

    Each entry maps a task name to a mapping of settings. The entry's
    ``type`` value is the key looked up in ``TASK_REGISTRY``; every other
    setting is forwarded to the registered class as a keyword argument.

    Args:
        config: Mapping with a ``tasks`` key whose value maps task names to
            per-task settings. ``config['tasks']`` is indexed directly, so a
            config without that key raises ``KeyError``.

    Returns:
        dict: Task name -> constructed task instance.

    Raises:
        ValueError: If a task entry has no ``type`` value (including an
            entry with no body at all), or if the registry lookup for that
            type returns ``None``.
    """
    task_instances = {}
    for task_name, task_config in config['tasks'].items():
        # A YAML entry with no body (``name:``) parses to None. Treat it as an
        # empty mapping so it is reported as a missing 'type' below instead of
        # failing with an AttributeError on ``None.get``.
        if task_config is None:
            task_config = {}

        # The 'type' field is the lookup key into the registry.
        task_type = task_config.get("type")
        if task_type is None:
            raise ValueError(f"Task '{task_name}' missing required 'type' field in config.")

        # Resolve the registered class.
        # NOTE(review): assumes TASK_REGISTRY.get returns None for unknown
        # keys; if the registry raises instead, its own error propagates.
        task_cls = TASK_REGISTRY.get(task_type)
        if task_cls is None:
            raise ValueError(f"Task type '{task_type}' not found in TASK_REGISTRY.")

        # Everything except 'type' becomes a constructor keyword argument.
        init_kwargs = {key: value for key, value in task_config.items() if key != "type"}
        task_instances[task_name] = task_cls(**init_kwargs)
    return task_instances