File size: 5,214 Bytes
5439a54
e6ef7ae
 
 
 
 
6994a32
5439a54
630839b
be022c2
6bda65c
 
d0097ff
 
 
 
e9e3777
d0097ff
80baa7f
 
5439a54
e9e3777
5439a54
ccef61e
 
e5347ad
e976342
e6ef7ae
5770b2e
e6ef7ae
 
 
6994a32
ccef61e
 
 
b80d473
6994a32
 
 
d0097ff
 
6994a32
80baa7f
6994a32
 
e6ef7ae
 
 
 
 
 
 
 
 
 
 
5770b2e
e6ef7ae
80baa7f
e6ef7ae
 
5770b2e
e6ef7ae
 
 
 
 
 
 
 
 
 
9f0d515
e6ef7ae
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88c10b9
e6ef7ae
 
 
80baa7f
e6ef7ae
 
 
 
ccef61e
 
 
 
 
 
e6ef7ae
e5347ad
5439a54
ccef61e
e6ef7ae
 
 
 
 
 
 
 
80baa7f
e6ef7ae
be022c2
a174711
5465cf2
be022c2
e6ef7ae
be022c2
44a8302
be022c2
e6ef7ae
 
 
630839b
6bda65c
 
 
5770b2e
 
 
 
6bda65c
 
 
 
e6ef7ae
6bda65c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
import logging
import requests
from airflow import DAG
from airflow.decorators import task
import json
import time
from dotenv import load_dotenv
from common import check_db_connection, default_args
from datetime import datetime, timedelta
from airflow.utils.task_group import TaskGroup
from airflow.operators.bash import BashOperator
from http.client import (
    TOO_MANY_REQUESTS, # 429
    INTERNAL_SERVER_ERROR # 500
)
from pathlib import Path

# Module-level logger, named after this module
logger = logging.getLogger(__name__)

# Load environment variables from the ".env" file located in the parent of this file's directory
load_dotenv(Path(__file__).parent.parent / ".env")

# Endpoint the transactions are fetched from; the TRANSACTION_PRODUCER_ENDPOINT
# environment variable takes precedence over the hard-coded default.
DEFAULT_TRANSACTION_PRODUCER_ENDPOINT = "https://charlestng-real-time-fraud-detection.hf.space/current-transactions"
TRANSACTION_PRODUCER_ENDPOINT = os.getenv("TRANSACTION_PRODUCER_ENDPOINT", DEFAULT_TRANSACTION_PRODUCER_ENDPOINT)
# Endpoint the transactions are posted to; there is no default, so this is None
# when the TRANSACTION_CONSUMER_ENDPOINT environment variable is not set.
TRANSACTION_CONSUMER_ENDPOINT = os.getenv("TRANSACTION_CONSUMER_ENDPOINT")

@task(task_id="pull_transaction")
def _pull_transaction(ti, prefix: str = ''):
    """
    Pulls a new transaction from the fraud detection service and pushes it to XCom.

    The producer endpoint is queried once; if it answers with HTTP 429 or 500,
    the request is retried a single time after a short pause.

    Args:
        ti: Task instance used to push the transaction to XCom.
        prefix: Prefix of the XCom key; the transaction is stored under
            "<prefix>_transaction_dict".

    Raises:
        requests.HTTPError: If the (possibly retried) request returns an error status.
        requests.RequestException: If the request fails or times out.
    """
    retry_delay = 15       # seconds to wait before the single retry
    request_timeout = 30   # seconds; without a timeout a stalled service would block the worker forever

    # Start from an empty mapping so the request also works when no API key is configured
    # (previously `headers` was left undefined in that case).
    headers = {}
    if TRANSACTION_PRODUCER_API_KEY := os.getenv("TRANSACTION_PRODUCER_API_KEY"):
        headers['Authorization'] = TRANSACTION_PRODUCER_API_KEY

    def get_current_transaction():
        return requests.get(
            url=TRANSACTION_PRODUCER_ENDPOINT,
            headers=headers,
            timeout=request_timeout,
        )

    response = get_current_transaction()

    # If status code is 429 or 500, wait for a few seconds and retry once
    if response.status_code in (TOO_MANY_REQUESTS, INTERNAL_SERVER_ERROR):
        logger.warning(
            "Received HTTP %s from the transaction producer. Retrying in %s seconds...",
            response.status_code,
            retry_delay,
        )
        time.sleep(retry_delay)
        response = get_current_transaction()

    # Check response status code
    response.raise_for_status()

    # The original code expects the body to be a JSON-encoded string that itself
    # contains JSON; an already-decoded object is accepted as well.
    payload = response.json()
    data = json.loads(payload) if isinstance(payload, str) else payload

    # Pair each column name with the matching value of the first returned row
    transaction_dict = dict(zip(data['columns'], data['data'][0]))

    # Push the transaction dictionary to XCom
    ti.xcom_push(key=f"{prefix}_transaction_dict", value=transaction_dict)

    logger.info("Fetched data: %s", transaction_dict)

@task(task_id="push_transaction")
def _push_transaction(ti, prefix: str = ''):
    """
    Pushes the transaction data to the fraud detection pipeline.

    Reads the transaction stored in XCom by the "pull_transaction" task of the
    same task group, renames its fields to the parameter names used in the
    request, and posts the result as JSON to the consumer endpoint.

    Args:
        ti: Task instance used to pull the transaction from XCom.
        prefix: Name of the enclosing task group; also the prefix of the XCom key
            ("<prefix>_transaction_dict").

    Raises:
        ValueError: If TRANSACTION_CONSUMER_ENDPOINT is not configured.
        KeyError: If the pulled transaction lacks one of the mapped source fields.
        requests.HTTPError: If the consumer endpoint returns an error status.
        requests.RequestException: If the request fails or times out.
    """
    request_timeout = 30  # seconds; without a timeout a stalled service would block the worker forever

    # request parameter name -> field name in the pulled transaction
    params_mapping = {
        'transaction_number': 'trans_num',
        'transaction_amount': 'amt',
        'transaction_timestamp': 'current_time',
        'transaction_category': 'category',
        'transaction_is_real_fraud': 'is_fraud',
        'customer_credit_card_number': 'cc_num',
        'customer_first_name': 'first',
        'customer_last_name': 'last',
        'customer_gender': 'gender',
        'merchant_name': 'merchant',
        'merchant_latitude': 'merch_lat',
        'merchant_longitude': 'merch_long',
        'customer_latitude': 'lat',
        'customer_longitude': 'long',
        'customer_city': 'city',
        'customer_state': 'state',
        'customer_zip': 'zip',
        'customer_city_population': 'city_pop',
        'customer_job': 'job',
        'customer_dob': 'dob',
        'is_fraud': 'is_fraud',
    }

    # Pull the transaction dictionary from XCom
    transaction_dict = ti.xcom_pull(f"transactions.{prefix}.pull_transaction", key=f"{prefix}_transaction_dict")

    # Nothing to send when the upstream task stored no (or an empty) transaction
    if not transaction_dict:
        logger.error("No transaction data found.")
        return

    # Fail with an explicit message instead of the "Invalid URL 'None'" error
    # that requests would raise for an unset endpoint.
    if not TRANSACTION_CONSUMER_ENDPOINT:
        raise ValueError("TRANSACTION_CONSUMER_ENDPOINT environment variable is not set.")

    # Build the request payload from the transaction dictionary
    payload = {param: transaction_dict[field] for param, field in params_mapping.items()}

    headers = {'Content-Type': 'application/json'}

    if TRANSACTION_CONSUMER_API_KEY := os.getenv("TRANSACTION_CONSUMER_API_KEY"):
        headers['Authorization'] = TRANSACTION_CONSUMER_API_KEY

    api_response = requests.post(
        url=TRANSACTION_CONSUMER_ENDPOINT,
        data=json.dumps(payload),
        headers=headers,
        timeout=request_timeout,
    )

    # Check response status code
    api_response.raise_for_status()

    # Decode the JSON answer of the pipeline for logging
    result = api_response.json()

    logger.info("Fraud detection response: %s", result)

# Work on a copy so the shared default_args stay untouched, then override the
# start date with a moment two minutes before this module is parsed.
dag_args = default_args.copy()
dag_args['start_date'] = datetime.now() - timedelta(minutes=2)

with DAG(
    dag_id="process_new_transaction",
    default_args=dag_args,
    max_active_runs=1,
    schedule_interval="*/1 * * * *",
) as dag:
    # DAG to fetch a new transaction and call the fraud detection pipeline
    check_db = check_db_connection()

    with TaskGroup(group_id='transactions') as all_tasks:
        # Three identical pull -> push pairs, each in its own nested task group
        for group in [f'transaction_{number}' for number in range(1, 4)]:
            with TaskGroup(group_id=group) as transaction_group:
                _pull_transaction(prefix=group) >> _push_transaction(prefix=group)

    end_dag = BashOperator(task_id="end_dag", bash_command="echo 'End!'")

    # Database check first, then all transaction groups, then the final marker task
    check_db >> all_tasks
    all_tasks >> end_dag