"""
Airflow API client for Mezura.
"""
|
import json
import logging
import os
import re
from typing import Dict, Any, Optional
from urllib.parse import quote

import requests

from api.config import get_api_config, get_api_config_for_type, get_airflow_config
|
|
|
|
|
# Module-level logger, named after this module, used by every method below.
logger = logging.getLogger(__name__)
|
|
|
class AirflowClient:
    """
    Client for interacting with the Airflow REST API (``/api/v1``).

    Every request is sent with the ``requests`` library using the
    username/password pair from the configuration as the ``auth`` tuple.
    The public ``send_*`` methods never raise: any failure is logged and
    reported to the caller as a dict with ``"status": "error"``.
    """

    # Matches text of the form "Logout (name)" and captures the name.
    # Compiled once here instead of on every send_dag_request() call.
    _LOGOUT_PATTERN = re.compile(r'Logout \(([^)]+)\)')

    # Translation from the "state" field of a DAG run response to the status
    # vocabulary returned by send_status_request(). Unlisted states map to
    # "unknown".
    _STATE_TO_STATUS = {
        "running": "running",
        "success": "completed",
        "failed": "failed",
        "queued": "pending",
    }

    def __init__(self):
        """
        Initialize the client from the mapping returned by get_airflow_config().

        Reads ``base_url``, ``auth.username``, ``auth.password``, ``timeout``
        (default 30) and ``retry_attempts`` (default 3).

        Raises:
            ValueError: If the base URL is missing, or if either credential
                is missing or empty.
        """
        airflow_config = get_airflow_config()

        self.airflow_base_url = airflow_config.get("base_url")
        if not self.airflow_base_url:
            raise ValueError("Airflow base URL not found in configuration")

        # `or {}` also covers an "auth" key that is present but null, which
        # would otherwise fail with AttributeError on the lookups below.
        auth = airflow_config.get("auth") or {}
        self.username = auth.get("username")
        self.password = auth.get("password")
        if not self.username or not self.password:
            raise ValueError(self._missing_credentials_message(auth))

        self.timeout = airflow_config.get("timeout", 30)
        # NOTE(review): stored but not consulted by any method of this class;
        # requests are currently sent exactly once.
        self.retry_attempts = airflow_config.get("retry_attempts", 3)

        logger.info(
            "Airflow API client initialized with base URL: %s",
            self.airflow_base_url,
        )

    @staticmethod
    def _missing_credentials_message(auth: Dict[str, Any]) -> str:
        """Build the error text for missing credentials.

        When ``auth["use_env"]`` is truthy and at least one of the configured
        environment variables is unset, the message names the unset
        variables; otherwise a generic message is returned.
        """
        if auth.get("use_env", False):
            env_names = (
                auth.get("env_username", "MEZURA_API_USERNAME"),
                auth.get("env_password", "MEZURA_API_PASSWORD"),
            )
            missing_vars = [
                name for name in env_names if os.environ.get(name) is None
            ]
            if missing_vars:
                return (
                    "Required environment variables not set: "
                    f"{', '.join(missing_vars)}"
                )
        return "Airflow authentication credentials not found in configuration"

    @classmethod
    def _normalize_username(cls, username: Any) -> Any:
        """Reduce a decorated username to the bare name it contains.

        * ``"Logout (alice)"`` -> ``"alice"`` (captured text, not stripped).
        * Otherwise, non-blank text inside the last ``(...)`` pair is
          returned stripped, e.g. ``"Alice Smith (asmith)"`` -> ``"asmith"``.
        * Anything else is returned unchanged, except that values which are
          not str/int/float/bool are converted with ``str()`` first.
        """
        if not isinstance(username, (str, int, float, bool)):
            username = str(username)
        text = str(username)

        match = cls._LOGOUT_PATTERN.search(text)
        if match:
            return match.group(1)

        # rfind/find return -1 instead of raising, so no try/except is needed.
        open_idx = text.rfind('(')
        if open_idx != -1:
            close_idx = text.find(')', open_idx + 1)
            if close_idx > open_idx + 1:
                extracted = text[open_idx + 1:close_idx].strip()
                if extracted:
                    return extracted
        return username

    def _build_url(self, *segments: Any) -> str:
        """Join ``segments`` under ``<base_url>/api/v1``.

        Each segment is percent-encoded with no safe characters, so an
        identifier containing ``/``, ``?``, ``#`` or ``+`` cannot alter the
        request target. A trailing ``/`` on the base URL is dropped to avoid
        producing ``//api/v1``.
        """
        base = self.airflow_base_url.rstrip("/")
        path = "/".join(quote(str(segment), safe="") for segment in segments)
        return f"{base}/api/v1/{path}"

    def _get(self, url: str):
        """Issue an authenticated GET that asks for a JSON response."""
        return requests.get(
            url,
            auth=(self.username, self.password),
            timeout=self.timeout,
            headers={'Accept': 'application/json'},
        )

    @staticmethod
    def _error_result(message: str, dag_id: str, **extra: Any) -> Dict[str, Any]:
        """Log ``message`` at error level and wrap it in an error dict.

        ``extra`` carries the additional keys some callers include
        (``run_id``, ``logs``).
        """
        logger.error(message)
        result = {"status": "error", "error": message, "dag_id": dag_id}
        result.update(extra)
        return result

    def send_dag_request(self, dag_id: str, conf: Dict[str, Any]) -> Dict[str, Any]:
        """
        Sends a request to start a DAG run.

        ``conf`` is shallow-copied; the caller's dict is not modified. It
        must contain a non-None ``"username"``, which is normalized (see
        ``_normalize_username``) before being sent.

        Args:
            dag_id: The ID of the DAG to run
            conf: The configuration for the DAG run

        Returns:
            Dict[str, Any]: On HTTP 200/201,
            ``{"run_id", "status": "submitted", "dag_id"}`` where ``run_id``
            is the response's ``dag_run_id`` (or ``"unknown"``). Otherwise
            ``{"status": "error", "error", "dag_id"}``.
        """
        # Broad catch is deliberate: this method reports every failure
        # (validation, network, serialization) as an error dict.
        try:
            airflow_endpoint = self._build_url("dags", dag_id, "dagRuns")
            conf_copy = conf.copy() if conf else {}

            if "username" not in conf_copy:
                logger.error("Username field missing from request configuration")
                raise ValueError("Username is required for benchmark submission")
            if conf_copy["username"] is None:
                logger.error("Username is None but required for API request")
                raise ValueError("Username is required for benchmark submission")
            conf_copy["username"] = self._normalize_username(conf_copy["username"])

            response = requests.post(
                airflow_endpoint,
                json={"conf": conf_copy},
                auth=(self.username, self.password),
                timeout=self.timeout,
                headers={
                    'Content-Type': 'application/json',
                    'Accept': 'application/json',
                },
            )
            logger.info("Response status code: %s", response.status_code)

            if response.status_code not in (200, 201):
                return self._error_result(
                    f"API Error: {response.status_code}, {response.text}",
                    dag_id,
                )

            try:
                data = response.json()
                logger.info("Response data: %s", json.dumps(data))
                run_id = data.get("dag_run_id", "unknown")
            except Exception as exc:
                return self._error_result(f"Error parsing response: {exc}", dag_id)

            logger.info("DAG run triggered: %s", run_id)
            return {
                "run_id": run_id,
                "status": "submitted",
                "dag_id": dag_id,
            }
        except Exception as exc:
            return self._error_result(f"Request failed: {exc}", dag_id)

    def send_status_request(self, dag_id: str, run_id: str) -> Dict[str, Any]:
        """
        Sends a status request to check the status of a DAG run.

        Args:
            dag_id: The ID of the DAG
            run_id: The DAG run ID returned by the send_dag_request method

        Returns:
            Dict[str, Any]: On HTTP 200, a dict with ``status`` (mapped from
            the response's ``state``), ``progress`` (100 for ``success``,
            else 0), ``current_step`` (the raw state), ``error``, ``run_id``
            and ``dag_id``. Otherwise
            ``{"status": "error", "error", "run_id", "dag_id"}``.
        """
        # Broad catch is deliberate: failures are returned, never raised.
        try:
            status_url = self._build_url("dags", dag_id, "dagRuns", run_id)
            logger.info(
                "Checking status for DAG run: %s, URL: %s", run_id, status_url
            )

            response = self._get(status_url)
            logger.info("Status response code: %s", response.status_code)

            if response.status_code != 200:
                return self._error_result(
                    f"Status API Error: {response.status_code}, {response.text}",
                    dag_id,
                    run_id=run_id,
                )

            try:
                data = response.json()
                state = data.get("state", "unknown")
                status_info = {
                    "status": self._STATE_TO_STATUS.get(state, "unknown"),
                    "progress": 100 if state == "success" else 0,
                    "current_step": state,
                    "error": "DAG execution failed" if state == "failed" else None,
                    "run_id": run_id,
                    "dag_id": dag_id,
                }
            except Exception as exc:
                return self._error_result(
                    f"Error parsing status response: {exc}",
                    dag_id,
                    run_id=run_id,
                )

            logger.info("DAG run status: %s", state)
            return status_info
        except Exception as exc:
            return self._error_result(
                f"Status request failed: {exc}", dag_id, run_id=run_id
            )

    def send_logs_request(self, dag_id: str, run_id: str, task_id: str = "process_results") -> Dict[str, Any]:
        """
        Sends a request to get the logs of a DAG run.

        Args:
            dag_id: The ID of the DAG
            run_id: The DAG run ID
            task_id: The task ID to get logs for, defaults to process_results

        Returns:
            Dict[str, Any]: On HTTP 200,
            ``{"logs": <response text>, "status": "success", "run_id",
            "dag_id"}``. Otherwise an error dict that also carries a
            placeholder ``logs`` message.
        """
        # Broad catch is deliberate: failures are returned, never raised.
        try:
            logs_url = self._build_url(
                "dags", dag_id, "dagRuns", run_id, "taskInstances", task_id, "logs"
            )
            logger.info("Getting logs for DAG run ID: %s, URL: %s", run_id, logs_url)

            response = self._get(logs_url)
            logger.info("Logs response code: %s", response.status_code)

            if response.status_code != 200:
                return self._error_result(
                    f"Logs API Error: {response.status_code}, {response.text}",
                    dag_id,
                    run_id=run_id,
                    logs="Failed to retrieve logs",
                )

            return {
                "logs": response.text,
                "status": "success",
                "run_id": run_id,
                "dag_id": dag_id,
            }
        except Exception as exc:
            return self._error_result(
                f"Logs request failed: {exc}",
                dag_id,
                run_id=run_id,
                logs="Failed to retrieve logs due to an error",
            )