Spaces:
Running
Running
import os | |
import logging | |
import requests | |
from airflow import DAG | |
from airflow.decorators import task | |
import json | |
import time | |
from dotenv import load_dotenv | |
from common import check_db_connection, default_args | |
from datetime import datetime, timedelta | |
from airflow.utils.task_group import TaskGroup | |
from airflow.operators.bash import BashOperator | |
from http.client import ( | |
TOO_MANY_REQUESTS, # 429 | |
INTERNAL_SERVER_ERROR # 500 | |
) | |
from pathlib import Path | |
logger = logging.getLogger(__name__)

# Configuration lives in a .env file located one directory above the one
# containing this file; load it before reading any environment variables.
_ENV_PATH = Path(__file__).parent.parent / ".env"
load_dotenv(_ENV_PATH)

# Source of new transactions; the environment may override the default URL.
DEFAULT_TRANSACTION_PRODUCER_ENDPOINT = (
    "https://charlestng-real-time-fraud-detection.hf.space/current-transactions"
)
TRANSACTION_PRODUCER_ENDPOINT = os.getenv(
    "TRANSACTION_PRODUCER_ENDPOINT",
    DEFAULT_TRANSACTION_PRODUCER_ENDPOINT,
)

# Destination (fraud detection pipeline) for pulled transactions.
# There is no default: this is None when the variable is unset.
TRANSACTION_CONSUMER_ENDPOINT = os.getenv("TRANSACTION_CONSUMER_ENDPOINT")
def _pull_transaction(ti, prefix: str = ''):
    """
    Pull one transaction from the producer endpoint and push it to XCom.

    The producer is called once. If it answers 429 (Too Many Requests) or
    500 (Internal Server Error), the call is retried a single time after a
    fixed pause. Any other error status, or an error on the retry, raises.

    The decoded body must contain a ``columns`` list and a ``data`` list of
    rows; only the first row is kept, zipped with the column names.

    Args:
        ti: Airflow task instance, used to push the result to XCom.
        prefix: Namespace for the XCom key. The transaction dict is pushed
            under ``f"{prefix}_transaction_dict"``.

    Raises:
        requests.HTTPError: if the final response has an error status.
        requests.Timeout: if the producer does not answer within the timeout.
        KeyError, IndexError: if the payload lacks ``columns`` or has no row.
    """
    retryable_statuses = (TOO_MANY_REQUESTS, INTERNAL_SERVER_ERROR)
    waiting_time = 15  # seconds to pause before the single retry
    request_timeout = 60  # seconds; without a timeout requests can block forever

    def get_current_transaction():
        # Start from empty headers so that a missing API key does not leave
        # `headers` unbound (or make this helper return None).
        headers = {}
        if api_key := os.getenv("TRANSACTION_PRODUCER_API_KEY"):
            headers['Authorization'] = api_key
        return requests.get(
            url=TRANSACTION_PRODUCER_ENDPOINT,
            headers=headers,
            timeout=request_timeout,
        )

    response = get_current_transaction()
    # Rate limiting (429) and transient server errors (500) get one retry.
    if response.status_code in retryable_statuses:
        logger.warning(
            "Producer returned HTTP %s. Retrying in %s seconds...",
            response.status_code,
            waiting_time,
        )
        time.sleep(waiting_time)
        response = get_current_transaction()
    response.raise_for_status()

    payload = response.json()
    # The producer appears to return a JSON string that wraps the actual JSON
    # document, so the decoded body is parsed a second time. An already
    # decoded object is accepted as-is.
    data = json.loads(payload) if isinstance(payload, str) else payload
    transaction_dict = dict(zip(data['columns'], data['data'][0]))

    ti.xcom_push(key=f"{prefix}_transaction_dict", value=transaction_dict)
    # Keep customer fields (card number, name, DOB, ...) out of INFO logs.
    logger.info("Fetched transaction %s", transaction_dict.get('trans_num'))
    logger.debug("Fetched data: %s", transaction_dict)
def _push_transaction(ti, prefix: str = ''):
    """
    Send the transaction pulled for ``prefix`` to the fraud detection pipeline.

    Reads the dict pushed to XCom by the ``pull_transaction`` task of the same
    task group. It renames the fields to the consumer's parameter names and
    POSTs the result as JSON to ``TRANSACTION_CONSUMER_ENDPOINT``.

    Args:
        ti: Airflow task instance, used to pull the transaction from XCom.
        prefix: Id of the task group holding the matching pull task
            (e.g. ``"transaction_1"``); also the prefix of the XCom key.

    Returns:
        None. If no transaction is found in XCom, an error is logged and the
        consumer is not called.

    Raises:
        KeyError: if the transaction lacks a field listed in ``params_mapping``.
        ValueError: if ``TRANSACTION_CONSUMER_ENDPOINT`` is not configured.
        requests.HTTPError: if the consumer answers with an error status.
        requests.Timeout: if the consumer does not answer within the timeout.
    """
    # Request parameter name -> field name in the pulled transaction dict.
    # NOTE(review): the source field 'is_fraud' is sent twice, as
    # 'transaction_is_real_fraud' and as 'is_fraud'; presumably the consumer
    # reads one of them. Confirm before removing either.
    params_mapping = {
        'transaction_number': 'trans_num',
        'transaction_amount': 'amt',
        'transaction_timestamp': 'current_time',
        'transaction_category': 'category',
        'transaction_is_real_fraud': 'is_fraud',
        'customer_credit_card_number': 'cc_num',
        'customer_first_name': 'first',
        'customer_last_name': 'last',
        'customer_gender': 'gender',
        'merchant_name': 'merchant',
        'merchant_latitude': 'merch_lat',
        'merchant_longitude': 'merch_long',
        'customer_latitude': 'lat',
        'customer_longitude': 'long',
        'customer_city': 'city',
        'customer_state': 'state',
        'customer_zip': 'zip',
        'customer_city_population': 'city_pop',
        'customer_job': 'job',
        'customer_dob': 'dob',
        'is_fraud': 'is_fraud',
    }
    request_timeout = 60  # seconds; without a timeout requests can block forever

    transaction_dict = ti.xcom_pull(
        task_ids=f"transactions.{prefix}.pull_transaction",
        key=f"{prefix}_transaction_dict",
    )
    if not transaction_dict:
        logger.error("No transaction data found.")
        return

    payload = {param: transaction_dict[field] for param, field in params_mapping.items()}

    # The endpoint has no default; fail with a clear message instead of the
    # obscure "Invalid URL 'None'" error requests would raise.
    if not TRANSACTION_CONSUMER_ENDPOINT:
        raise ValueError(
            "TRANSACTION_CONSUMER_ENDPOINT is not set; cannot push the transaction."
        )

    headers = {'Content-Type': 'application/json'}
    if api_key := os.getenv("TRANSACTION_CONSUMER_API_KEY"):
        headers['Authorization'] = api_key

    api_response = requests.post(
        url=TRANSACTION_CONSUMER_ENDPOINT,
        data=json.dumps(payload),
        headers=headers,
        timeout=request_timeout,
    )
    api_response.raise_for_status()
    logger.info("Fraud detection response: %s", api_response.json())
# Copy the shared default_args so the overrides below do not leak into other DAGs.
dag_args = default_args.copy()
# A fixed start date keeps the schedule stable across scheduler re-parses
# (a datetime.now()-based start_date moves on every parse). catchup=False on
# the DAG prevents backfilling every minute since this date; only the latest
# interval is scheduled, which matches the intent of the original
# "start two minutes ago" value.
dag_args['start_date'] = datetime(2024, 1, 1)

# Number of independent pull -> push task pairs created in each DAG run.
TRANSACTIONS_PER_RUN = 3

with DAG(dag_id="process_new_transaction",
         default_args=dag_args,
         max_active_runs=1,
         catchup=False,
         schedule_interval="*/1 * * * *",
         doc_md="DAG to fetch a new transaction and call the fraud detection pipeline") as dag:
    check_db = check_db_connection()

    with TaskGroup(group_id='transactions') as all_tasks:
        for i in range(1, TRANSACTIONS_PER_RUN + 1):
            group = f'transaction_{i}'
            with TaskGroup(group_id=group):
                # The helpers are plain functions expecting the task instance
                # `ti`; wrapping them with @task turns them into tasks and lets
                # Airflow inject `ti` at run time. The task ids must match the
                # id _push_transaction pulls from:
                # "transactions.<group>.pull_transaction".
                pull = task(task_id="pull_transaction")(_pull_transaction)(prefix=group)
                push = task(task_id="push_transaction")(_push_transaction)(prefix=group)
                pull >> push

    end_dag = BashOperator(task_id="end_dag", bash_command="echo 'End!'")

    check_db >> all_tasks >> end_dag