Deepaksai1 commited on
Commit
bca978b
·
verified ·
1 Parent(s): e78ef66

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +137 -0
app.py ADDED
@@ -0,0 +1,137 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification, BertTokenizer, BertForSequenceClassification
4
+ import torch.nn.functional as F
5
+
6
+ # Load models
7
+ albert_model = AlbertForSequenceClassification.from_pretrained("Deepaksai1/albert-fraud-detector-v2").eval()
8
+ albert_tokenizer = AlbertTokenizer.from_pretrained("Deepaksai1/albert-fraud-detector-v2")
9
+ finbert_model = BertForSequenceClassification.from_pretrained("Deepaksai1/finbert-fraud-detector-v2").eval()
10
+ finbert_tokenizer = BertTokenizer.from_pretrained("Deepaksai1/finbert-fraud-detector-v2")
11
+
12
+ # Feature engineering function
13
+ def engineer_features(step, tx_type, amount, old_org, new_org, old_dest, new_dest):
14
+ # Calculate derived features
15
+ orig_diff = amount - (old_org - new_org)
16
+ dest_diff = (new_dest - old_dest) - amount
17
+ zero_balance = 1 if new_org == 0 else 0
18
+ amount_fraction = amount / old_org if old_org > 0 else 0
19
+
20
+ # Enhanced text representation with engineered features
21
+ text = (f"Step: {step}, Type: {tx_type}, Amount: {amount}, "
22
+ f"OldBalOrig: {old_org}, NewBalOrig: {new_org}, "
23
+ f"OldBalDest: {old_dest}, NewBalDest: {new_dest}, "
24
+ f"OrigDiff: {orig_diff}, DestDiff: {dest_diff}, "
25
+ f"ZeroBalance: {zero_balance}, AmountFraction: {amount_fraction}")
26
+
27
+ # Return text for transformer models and transaction metadata
28
+ metadata = {
29
+ 'amount': amount,
30
+ 'zero_balance': zero_balance,
31
+ 'orig_diff': orig_diff
32
+ }
33
+ return text, metadata
34
+
35
+ # Individual model prediction
36
+ def predict_single_model(text, model_name):
37
+ tokenizer = albert_tokenizer if model_name == "ALBERT" else finbert_tokenizer
38
+ model = albert_model if model_name == "ALBERT" else finbert_model
39
+
40
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=128)
41
+ with torch.no_grad():
42
+ outputs = model(**inputs)
43
+ probs = F.softmax(outputs.logits, dim=1)
44
+ fraud_score = probs[0][1].item()
45
+
46
+ return fraud_score
47
+
48
+ # Ensemble prediction with adaptive thresholding
49
+ def ensemble_predict(step, tx_type, amount, old_org, new_org, old_dest, new_dest, use_ensemble=True):
50
+ # Engineer features
51
+ text, metadata = engineer_features(step, tx_type, amount, old_org, new_org, old_dest, new_dest)
52
+
53
+ # Get individual model predictions
54
+ albert_score = predict_single_model(text, "ALBERT")
55
+ finbert_score = predict_single_model(text, "FinBERT")
56
+
57
+ if use_ensemble:
58
+ # Weighted ensemble (ALBERT performs better so weighted higher)
59
+ weights = {"ALBERT": 0.6, "FinBERT": 0.4}
60
+ ensemble_score = weights["ALBERT"] * albert_score + weights["FinBERT"] * finbert_score
61
+
62
+ # Adaptive thresholding based on transaction characteristics
63
+ base_threshold = 0.5
64
+ if metadata['amount'] > 1000000: # High-value transaction
65
+ threshold = base_threshold - 0.1 # Lower threshold for high-risk
66
+ elif metadata['zero_balance'] == 1: # Account emptying
67
+ threshold = base_threshold - 0.15
68
+ elif abs(metadata['orig_diff']) > 1000: # Suspicious balance difference
69
+ threshold = base_threshold - 0.08
70
+ else:
71
+ threshold = base_threshold
72
+
73
+ is_fraud = ensemble_score > threshold
74
+ result = "Fraud" if is_fraud else "Not Fraud"
75
+
76
+ # Return individual scores as well for transparency
77
+ return result, ensemble_score, albert_score, finbert_score, threshold
78
+ else:
79
+ # For comparison, return individual model results
80
+ return "See individual scores", 0, albert_score, finbert_score, 0.5
81
+
82
+ # Gradio Interface
83
+ with gr.Blocks() as demo:
84
+ gr.Markdown("## 🔎 Advanced Hybrid Fraud Detection System")
85
+
86
+ with gr.Row():
87
+ step = gr.Number(label="Step", value=1)
88
+ tx_type = gr.Dropdown(choices=["CASH_OUT", "TRANSFER", "PAYMENT", "DEBIT", "CASH_IN"],
89
+ label="Transaction Type")
90
+ amount = gr.Number(label="Amount", value=0.0)
91
+
92
+ with gr.Row():
93
+ old_org = gr.Number(label="Old Balance Orig", value=0.0)
94
+ new_org = gr.Number(label="New Balance Orig", value=0.0)
95
+
96
+ with gr.Row():
97
+ old_dest = gr.Number(label="Old Balance Dest", value=0.0)
98
+ new_dest = gr.Number(label="New Balance Dest", value=0.0)
99
+
100
+ with gr.Row():
101
+ use_ensemble = gr.Checkbox(label="Use Ensemble Model", value=True)
102
+
103
+ with gr.Row():
104
+ predict_btn = gr.Button("Predict")
105
+
106
+ with gr.Row():
107
+ pred_label = gr.Label(label="Final Prediction")
108
+ ensemble_score = gr.Number(label="Ensemble Score")
109
+
110
+ with gr.Row():
111
+ albert_score = gr.Number(label="ALBERT Score")
112
+ finbert_score = gr.Number(label="FinBERT Score")
113
+ threshold = gr.Number(label="Applied Threshold")
114
+
115
+ # Bind function
116
+ predict_btn.click(
117
+ fn=ensemble_predict,
118
+ inputs=[step, tx_type, amount, old_org, new_org, old_dest, new_dest, use_ensemble],
119
+ outputs=[pred_label, ensemble_score, albert_score, finbert_score, threshold]
120
+ )
121
+
122
+ # Example transactions
123
+ examples = [
124
+ [151, "CASH_OUT", 1633227.0, 1633227.0, 0.0, 2865353.22, 4498580.23, True],
125
+ [353, "CASH_OUT", 174566.53, 174566.53, 0.0, 1191715.74, 1366282.27, True],
126
+ [357, "TRANSFER", 484493.06, 484493.06, 0.0, 0.0, 0.0, True],
127
+ [43, "CASH_OUT", 81571.63, 0.0, 0.0, 176194.2, 257765.83, True],
128
+ [307, "DEBIT", 247.82, 11544.0, 11296.18, 3550535.53, 3550783.36, True],
129
+ [350, "DEBIT", 4330.57, 3766.0, 0.0, 239435.41, 243765.98, True]
130
+ ]
131
+
132
+ gr.Examples(examples=examples,
133
+ inputs=[step, tx_type, amount, old_org, new_org, old_dest, new_dest, use_ensemble])
134
+
135
+ # Launch app
136
+ if __name__ == "__main__":
137
+ demo.launch()