| """
|
| Utilities for safely accessing configuration values regardless of whether they're stored
|
| in dictionaries or objects.
|
| """
|
|
|
def safe_get_config(config_obj, key, default=None):
    """
    Safely get a configuration value from a mapping or an attribute-style object.

    Args:
        config_obj: Configuration container, or None.
            - Any ``collections.abc.Mapping`` (``dict``, ``MappingProxyType``,
              custom mappings, ...) is read by key.
            - Every other object is read by attribute.
        key: Key or attribute name to look up.
        default: Value returned when ``config_obj`` is None or the
            key/attribute is missing.

    Returns:
        The stored value, or ``default`` if it cannot be found. A key or
        attribute that is present but set to None is returned as None (the
        default is not substituted).
    """
    # Local import keeps this helper self-contained regardless of where the
    # module-level imports sit in the file.
    from collections.abc import Mapping

    if config_obj is None:
        return default

    # Test against the Mapping ABC rather than ``dict``. Otherwise read-only or
    # custom mappings would fall through to getattr, which never finds keys
    # and would silently return ``default``.
    if isinstance(config_obj, Mapping):
        return config_obj.get(key, default)

    return getattr(config_obj, key, default)
|
|
|
def get_model_name(config_obj):
    """
    Resolve the transformer model name from a configuration.

    The lookup is two-level:
        1. Read ``TRANSFORMER_CONFIG`` from ``config_obj``.
        2. Read ``MODEL_NAME`` from that section.

    Both reads go through ``safe_get_config``, so either level may be a dict,
    an attribute-style object, or None.

    Args:
        config_obj: Configuration object (dict, attribute-style object, or None).

    Returns:
        The configured model name, or ``"bert-base-uncased"`` when the
        ``TRANSFORMER_CONFIG`` section is missing/None or lacks ``MODEL_NAME``.
    """
    fallback_name = "bert-base-uncased"
    section = safe_get_config(config_obj, "TRANSFORMER_CONFIG", {})
    model_name = safe_get_config(section, "MODEL_NAME", fallback_name)
    return model_name
|
|
|
def get_embedding_model(config_obj):
    """
    Obtain the embedding model named by the configuration.

    The steps are:
        1. Resolve the model name with ``get_model_name``.
        2. Pass that name to
           ``utils.transformer_utils.get_sentence_transformer``.

    Whether a new model is created or a cached one is returned is up to that
    function.

    Args:
        config_obj: Configuration object (dict, attribute-style object, or None).

    Returns:
        Whatever ``get_sentence_transformer`` returns for the resolved name.
    """
    # Deferred import: the transformer utilities are only loaded when an
    # embedding model is actually requested.
    from utils.transformer_utils import get_sentence_transformer as load_transformer

    return load_transformer(get_model_name(config_obj))
|
|
|
| """Utilities for configuration validation and fixes"""
|
import logging
import os
from collections.abc import Mapping, MutableMapping
from typing import Any, Dict
|
|
|
# Module-level logger; validate_config reports each applied default and its
# completion message through it.
logger = logging.getLogger(__name__)
|
|
|
# Defaults inserted into a config's TRANSFORMER_CONFIG when a key is missing.
# MODEL_NAME should stay in sync with the fallback used by get_model_name.
_TRANSFORMER_DEFAULTS = {
    "MAX_SEQ_LENGTH": 512,
    "MODEL_NAME": "bert-base-uncased",
    "NUM_LAYERS": 6,
    "EMBEDDING_DIM": 768,
    "NUM_HEADS": 12,
    "HIDDEN_DIM": 768,
    "DROPOUT": 0.1,
    "POOLING_MODE": "mean",
}


def _config_value(config: Any, key: str, default: Any) -> Any:
    """Read ``key`` from a mapping config, or attribute ``key`` from any other object."""
    if isinstance(config, Mapping):
        return config.get(key, default)
    return getattr(config, key, default)


def _fill_transformer_defaults(tc: MutableMapping) -> None:
    """Insert every missing ``_TRANSFORMER_DEFAULTS`` entry into ``tc`` in place, logging each."""
    for key, value in _TRANSFORMER_DEFAULTS.items():
        if key not in tc:
            tc[key] = value
            logger.info("Added default %s=%s to TRANSFORMER_CONFIG", key, value)


def _config_dir(config: Any, key: str, fallback: str) -> str:
    """Return the directory stored under ``key``; use ``fallback`` if it is unset, None or empty."""
    value = _config_value(config, key, None)
    # None/"" would make os.makedirs raise TypeError/FileNotFoundError.
    return value if value else fallback


def validate_config(config: Any) -> None:
    """
    Validate a configuration in place, filling defaults and creating directories.

    ``config`` may be a mapping (read by key) or any other object (read by
    attribute). This is the same dict-or-object convention used by
    ``safe_get_config``.

    Steps:
        1. If ``TRANSFORMER_CONFIG`` is a mutable mapping, add every missing
           entry from ``_TRANSFORMER_DEFAULTS`` and log each addition.
           Existing values are kept. Any other type, including absent or
           None, is left untouched.
        2. Create ``DATA_DIR``. If it is unset, None or empty, fall back to
           ``$TLM_DATA_DIR`` if non-empty, otherwise ``/tmp/tlm_data``.
        3. Create ``MODEL_DIR``. If it is unset, None or empty, fall back to
           ``<data_dir>/models``.

    Args:
        config: Configuration object or mapping. None is accepted and only
            triggers directory creation with the fallbacks.

    Raises:
        OSError: If a directory cannot be created.
    """
    tc = _config_value(config, "TRANSFORMER_CONFIG", None)
    if isinstance(tc, MutableMapping):
        _fill_transformer_defaults(tc)

    # An empty TLM_DATA_DIR is treated like an unset one.
    env_data_dir = os.environ.get("TLM_DATA_DIR") or "/tmp/tlm_data"
    data_dir = _config_dir(config, "DATA_DIR", env_data_dir)
    os.makedirs(data_dir, exist_ok=True)

    model_dir = _config_dir(config, "MODEL_DIR", os.path.join(data_dir, "models"))
    os.makedirs(model_dir, exist_ok=True)

    logger.info("Configuration validated and fixed")
|
|
|