File size: 5,535 Bytes
3232d64
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
API configuration module.
"""
import base64
import copy
import logging
import os
from typing import Dict, Any, Optional

import yaml

# Module-level logger for this configuration module.
logger = logging.getLogger(__name__)

# Parsed YAML configuration. get_api_config() fills this on its first
# successful load and returns it as-is on later calls (None = not loaded yet).
_config_cache: Optional[Dict[str, Any]] = None

# Supported base model names; replaced by update_base_model_list() and
# returned by get_base_model_list(). Starts out empty.
_supported_base_models: list = []

def get_api_config() -> Dict[str, Any]:
    """
    Return the API configuration, loading and caching it on first use.

    The YAML file is located via the ``MEZURA_CONFIG_PATH`` environment
    variable (default ``config/api_config.yaml``). After the first successful
    load the parsed result is cached at module level, and later calls return
    the cached object without re-reading the file.

    Returns:
        Dict[str, Any]: API configuration. An empty YAML file yields ``{}``.

    Raises:
        FileNotFoundError: If the configuration file does not exist.
        RuntimeError: If the file exists but cannot be read or parsed.
    """
    global _config_cache

    # Use cache if available
    if _config_cache is not None:
        return _config_cache

    config_path = os.environ.get("MEZURA_CONFIG_PATH", "config/api_config.yaml")
    if not os.path.exists(config_path):
        error_msg = f"Configuration file not found at {config_path}"
        logger.error(error_msg)
        raise FileNotFoundError(error_msg)

    try:
        # YAML files are UTF-8; don't depend on the platform's locale encoding.
        with open(config_path, 'r', encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except Exception as e:
        logger.error(f"Error loading config from {config_path}: {e}")
        raise RuntimeError(f"Failed to load configuration from {config_path}: {e}") from e

    # yaml.safe_load returns None for an empty document. Normalise it to an
    # empty mapping so callers can use `in` / `.get`, and so the result is
    # cached instead of the file being re-read on every call.
    if config is None:
        logger.warning(f"Configuration file {config_path} is empty")
        config = {}

    _config_cache = config
    logger.info(f"Loaded API configuration from {config_path}")
    return config

def get_api_config_for_type(evaluation_type: str) -> Dict[str, Any]:
    """
    Look up the API settings for a single evaluation type.

    The type name is normalised (lower-cased, dashes turned into underscores)
    before it is matched against the keys of the ``apis`` section.

    Args:
        evaluation_type: Evaluation type (e.g., "evalmix")

    Returns:
        Dict[str, Any]: The matching ``apis`` entry if there is one, otherwise
        the top-level ``default`` section, otherwise an empty dict.
    """
    full_config = get_api_config()
    lookup_key = evaluation_type.lower().replace("-", "_")

    has_apis_section = "apis" in full_config
    if has_apis_section and lookup_key in full_config["apis"]:
        selected = full_config["apis"][lookup_key]
        logger.debug(f"Using config for {evaluation_type}: {selected}")
        return selected

    # No type-specific entry: fall back to `default`, or to nothing at all.
    if "default" not in full_config:
        logger.warning(f"No configuration found for {evaluation_type} and no default config")
        return {}

    logger.warning(f"No configuration found for {evaluation_type}, using default")
    return full_config["default"]

def _redact_airflow_config(airflow_config: Dict[str, Any]) -> Dict[str, Any]:
    """
    Return a version of *airflow_config* that is safe to write to logs.

    If ``auth.password`` is present, it is replaced by ``"***"`` in a shallow
    copy. The input dictionary itself is never modified.
    """
    auth = airflow_config.get("auth")
    if isinstance(auth, dict) and "password" in auth:
        return {**airflow_config, "auth": {**auth, "password": "***"}}
    return airflow_config


def get_airflow_config() -> Dict[str, Any]:
    """
    Get Airflow API configuration.

    Reads the ``apis.airflow`` section of the loaded YAML configuration and
    returns a *copy* of it with environment overrides applied:

    * ``AIRFLOW_URL``, when set and non-empty, replaces ``base_url``.
    * When ``auth.use_env`` is true, ``auth.username`` / ``auth.password`` are
      filled from the environment variables named by ``auth.env_username`` /
      ``auth.env_password`` (defaults ``MEZURA_API_USERNAME`` /
      ``MEZURA_API_PASSWORD``). This happens only if both are set and non-empty.

    A copy is returned so that environment-derived values, including
    credentials, are never written back into the cached configuration. There
    they would go stale and be visible to every other reader of the cache.

    Returns:
        Dict[str, Any]: Airflow API configuration, or ``{}`` if the config
        file has no usable ``apis.airflow`` section.
    """
    config = get_api_config() or {}
    apis = config.get("apis") if isinstance(config, dict) else None
    if not isinstance(apis, dict) or "airflow" not in apis:
        # If not found in config, log warning and return empty dict
        logger.warning("Airflow configuration not found in config file")
        return {}

    # Deep copy: the section lives inside the module-level config cache, and
    # the overrides below must not leak into it. An empty `airflow:` key
    # parses as None, which is treated as an empty section.
    airflow_config = copy.deepcopy(apis["airflow"]) or {}
    if not isinstance(airflow_config, dict):
        logger.warning("Airflow configuration in config file is not a mapping")
        return {}
    logger.debug(f"Using Airflow config from YAML: {_redact_airflow_config(airflow_config)}")

    # Environment variable takes precedence over the YAML base_url
    env_base_url = os.environ.get("AIRFLOW_URL")
    if env_base_url:
        airflow_config["base_url"] = env_base_url
        logger.info(f"Loaded Airflow base_url from environment variable AIRFLOW_URL: {env_base_url}")
    else:
        logger.info(f"Using Airflow base_url from YAML config: {airflow_config.get('base_url')}")

    # An empty `auth:` key parses as None; treat it as "no auth settings".
    auth_config = airflow_config.get("auth") or {}
    if isinstance(auth_config, dict) and auth_config.get("use_env", False):
        username_env = auth_config.get("env_username", "MEZURA_API_USERNAME")
        password_env = auth_config.get("env_password", "MEZURA_API_PASSWORD")
        logger.info(f"Looking for credentials in environment variables: {username_env}, {password_env}")

        username = os.environ.get(username_env)
        password = os.environ.get(password_env)

        # Log only presence, never the values themselves.
        logger.info(f"Username variable '{username_env}' found: {username is not None}")
        logger.info(f"Password variable '{password_env}' found: {password is not None}")

        if username and password:
            auth_config["username"] = username
            auth_config["password"] = password
            airflow_config["auth"] = auth_config
        else:
            logger.warning(
                f"Airflow auth.use_env is enabled but '{username_env}' and/or "
                f"'{password_env}' is missing or empty; credentials were not "
                f"loaded from the environment"
            )

    return airflow_config

def update_base_model_list(models):
    """
    Replace the list of supported base models.

    The names are copied into a new list. This has two effects:

    * Later changes to the caller's sequence do not silently alter the
      module-level state.
    * A one-shot iterable (e.g. a generator) is materialised rather than
      stored unconsumed.

    Args:
        models (iterable): Supported model names. ``None`` clears the list.
    """
    global _supported_base_models
    _supported_base_models = list(models) if models is not None else []

def get_base_model_list():
    """
    Fetch the supported base model names currently registered in this module.

    This is whatever update_base_model_list() last stored, or the initial
    empty list if it has never been called.

    Returns:
        list: The module-level list of supported model names itself (not a
        copy).
    """
    registered_models = _supported_base_models
    return registered_models