Spaces:
Running
Running
File size: 5,214 Bytes
5439a54 e6ef7ae 6994a32 5439a54 630839b be022c2 6bda65c d0097ff e9e3777 d0097ff 80baa7f 5439a54 e9e3777 5439a54 ccef61e e5347ad e976342 e6ef7ae 5770b2e e6ef7ae 6994a32 ccef61e b80d473 6994a32 d0097ff 6994a32 80baa7f 6994a32 e6ef7ae 5770b2e e6ef7ae 80baa7f e6ef7ae 5770b2e e6ef7ae 9f0d515 e6ef7ae 88c10b9 e6ef7ae 80baa7f e6ef7ae ccef61e e6ef7ae e5347ad 5439a54 ccef61e e6ef7ae 80baa7f e6ef7ae be022c2 a174711 5465cf2 be022c2 e6ef7ae be022c2 44a8302 be022c2 e6ef7ae 630839b 6bda65c 5770b2e 6bda65c e6ef7ae 6bda65c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import os
import logging
import requests
from airflow import DAG
from airflow.decorators import task
import json
import time
from dotenv import load_dotenv
from common import check_db_connection, default_args
from datetime import datetime, timedelta
from airflow.utils.task_group import TaskGroup
from airflow.operators.bash import BashOperator
from http.client import (
TOO_MANY_REQUESTS, # 429
INTERNAL_SERVER_ERROR # 500
)
from pathlib import Path
logger = logging.getLogger(__name__)

# Populate os.environ from the project's .env file, located one directory above
# the folder that contains this DAG file. This must run before the lookups below.
load_dotenv(dotenv_path=Path(__file__).parent.parent / ".env")

# Producer: service the DAG pulls new transactions from (overridable via env).
DEFAULT_TRANSACTION_PRODUCER_ENDPOINT = (
    "https://charlestng-real-time-fraud-detection.hf.space/current-transactions"
)
TRANSACTION_PRODUCER_ENDPOINT = os.environ.get(
    "TRANSACTION_PRODUCER_ENDPOINT", DEFAULT_TRANSACTION_PRODUCER_ENDPOINT
)
# Consumer: service transactions are posted to. No default, so None when unset.
TRANSACTION_CONSUMER_ENDPOINT = os.environ.get("TRANSACTION_CONSUMER_ENDPOINT")
@task(task_id="pull_transaction")
def _pull_transaction(ti, prefix: str = ''):
    """
    Pull the current transaction from the producer service and push it to XCom.

    Calls ``TRANSACTION_PRODUCER_ENDPOINT`` once. If the answer is HTTP 429
    (Too Many Requests) or 500 (Internal Server Error), the call is retried a
    single time after a fixed pause. The decoded payload must contain a
    ``columns`` list and a ``data`` list of rows; only the first row is used.
    It presumably is a pandas ``orient="split"`` frame, to be confirmed against
    the producer.

    Args:
        ti: Airflow TaskInstance, injected from the task context.
        prefix: Namespace for the XCom key. The transaction dict is pushed under
            ``f"{prefix}_transaction_dict"``.

    Raises:
        requests.HTTPError: If the final response has an error status code.
        ValueError: If the payload contains no rows.
    """
    retry_statuses = (TOO_MANY_REQUESTS, INTERNAL_SERVER_ERROR)
    waiting_time = 15  # seconds to wait before the single retry
    request_timeout = 30  # seconds; without a timeout requests can block forever

    def get_current_transaction():
        # Always bind `headers`, and only add Authorization when a key is
        # configured. Previously `headers` was undefined when no key was set.
        headers = {}
        if api_key := os.getenv("TRANSACTION_PRODUCER_API_KEY"):
            headers['Authorization'] = api_key
        return requests.get(
            url=TRANSACTION_PRODUCER_ENDPOINT,
            headers=headers,
            timeout=request_timeout,
        )

    response = get_current_transaction()
    # Rate limiting (429) or a transient server error (500): wait, then retry once.
    if response.status_code in retry_statuses:
        logger.warning(
            f"Producer returned HTTP {response.status_code}. "
            f"Retrying in {waiting_time} seconds..."
        )
        time.sleep(waiting_time)
        response = get_current_transaction()

    response.raise_for_status()

    payload = response.json()
    # The producer's body decodes to a JSON *string* that needs a second decode.
    # An already-decoded object is accepted as-is.
    if isinstance(payload, str):
        payload = json.loads(payload)

    rows = payload['data']
    if not rows:
        raise ValueError("Producer returned no transaction rows.")
    transaction_dict = dict(zip(payload['columns'], rows[0]))

    # The downstream push task reads this back using the same prefixed key.
    ti.xcom_push(key=f"{prefix}_transaction_dict", value=transaction_dict)
    logger.info(f"Fetched data: {transaction_dict}")
@task(task_id="push_transaction")
def _push_transaction(ti, prefix: str = ''):
    """
    Post a previously pulled transaction to the fraud detection consumer service.

    Reads the transaction dict pushed by the sibling ``pull_transaction`` task in
    the ``transactions.<prefix>`` task group. Each source field is renamed to the
    consumer's parameter name via ``params_mapping``, and the result is POSTed as
    JSON to ``TRANSACTION_CONSUMER_ENDPOINT``.

    Args:
        ti: Airflow TaskInstance, injected from the task context.
        prefix: Task-group id of the matching pull task; also prefixes the XCom key.

    Raises:
        ValueError: If ``TRANSACTION_CONSUMER_ENDPOINT`` is not configured.
        KeyError: If the transaction lacks any field listed in ``params_mapping``.
        requests.HTTPError: If the consumer responds with an error status code.
    """
    request_timeout = 30  # seconds; without a timeout requests can block forever

    # consumer parameter name -> field name in the pulled transaction
    params_mapping = {
        'transaction_number': 'trans_num',
        'transaction_amount': 'amt',
        'transaction_timestamp': 'current_time',
        'transaction_category': 'category',
        # Both this key and 'is_fraud' below are filled from the same source
        # field; presumably the consumer expects both. NOTE(review): confirm.
        'transaction_is_real_fraud': 'is_fraud',
        'customer_credit_card_number': 'cc_num',
        'customer_first_name': 'first',
        'customer_last_name': 'last',
        'customer_gender': 'gender',
        'merchant_name': 'merchant',
        'merchant_latitude': 'merch_lat',
        'merchant_longitude': 'merch_long',
        'customer_latitude': 'lat',
        'customer_longitude': 'long',
        'customer_city': 'city',
        'customer_state': 'state',
        'customer_zip': 'zip',
        'customer_city_population': 'city_pop',
        'customer_job': 'job',
        'customer_dob': 'dob',
        'is_fraud': 'is_fraud',
    }

    # The pull task lives in the nested group "transactions.<prefix>".
    transaction_dict = ti.xcom_pull(
        task_ids=f"transactions.{prefix}.pull_transaction",
        key=f"{prefix}_transaction_dict",
    )
    if not transaction_dict:
        logger.error("No transaction data found.")
        return

    if not TRANSACTION_CONSUMER_ENDPOINT:
        raise ValueError(
            "TRANSACTION_CONSUMER_ENDPOINT is not set; cannot push the transaction."
        )

    # Report every missing field at once instead of failing on the first one.
    missing = sorted({src for src in params_mapping.values() if src not in transaction_dict})
    if missing:
        raise KeyError(f"Transaction is missing expected fields: {missing}")

    payload = {param: transaction_dict[src] for param, src in params_mapping.items()}

    headers = {'Content-Type': 'application/json'}
    if api_key := os.getenv("TRANSACTION_CONSUMER_API_KEY"):
        headers['Authorization'] = api_key

    api_response = requests.post(
        url=TRANSACTION_CONSUMER_ENDPOINT,
        data=json.dumps(payload),
        headers=headers,
        timeout=request_timeout,
    )
    api_response.raise_for_status()

    result = api_response.json()
    logger.info(f"Fraud detection response: {result}")
# Copy the shared default_args so this DAG can override fields without mutating them.
dag_args = default_args.copy()
# Use a fixed start_date. Airflow advises against dynamic values like
# datetime.now(), which shift on every DAG-file parse. catchup=False (below)
# stops the scheduler from back-filling every minute since this date.
dag_args['start_date'] = datetime(2024, 1, 1)

# Number of independent pull -> push pipelines created in each DAG run.
TRANSACTIONS_PER_RUN = 3

# NOTE(review): Airflow >= 2.4 prefers `schedule=`; `schedule_interval` is kept
# for compatibility with older Airflow versions.
with DAG(
    dag_id="process_new_transaction",
    default_args=dag_args,
    max_active_runs=1,
    schedule_interval="*/1 * * * *",
    catchup=False,
    doc_md="DAG to fetch a new transaction and call the fraud detection pipeline",
) as dag:
    check_db = check_db_connection()

    # Group ids become "transactions.transaction_<i>". _push_transaction relies
    # on this naming when it pulls the XCom written by its sibling pull task.
    with TaskGroup(group_id='transactions') as all_tasks:
        for i in range(1, TRANSACTIONS_PER_RUN + 1):
            group = f'transaction_{i}'
            with TaskGroup(group_id=group):
                pull = _pull_transaction(prefix=group)
                push = _push_transaction(prefix=group)
                pull >> push

    end_dag = BashOperator(task_id="end_dag", bash_command="echo 'End!'")

    check_db >> all_tasks >> end_dag
|