Deepaksai1 commited on
Commit
3c970f0
·
verified ·
1 Parent(s): d5c8e54

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from transformers import AlbertTokenizer, AlbertForSequenceClassification, BertTokenizer, BertForSequenceClassification
4
+ import numpy as np
5
+ import joblib
6
+ import os
7
+
8
+ # Load ALBERT model
9
+ albert_model = AlbertForSequenceClassification.from_pretrained("Deepaksai1/albert-fraud-detector")
10
+ albert_tokenizer = AlbertTokenizer.from_pretrained("Deepaksai1/albert-fraud-detector")
11
+ albert_model.eval()
12
+
13
+ # Load FinBERT model
14
+ finbert_model = BertForSequenceClassification.from_pretrained("Deepaksai1/finbert-fraud-detector")
15
+ finbert_tokenizer = BertTokenizer.from_pretrained("Deepaksai1/finbert-fraud-detector")
16
+ finbert_model.eval()
17
+
18
+ # Load CatBoost model
19
+ from catboost import CatBoostClassifier
20
+ catboost_model_path = "catboost_fraud_model.cbm"
21
+ catboost_model = CatBoostClassifier()
22
+ catboost_model.load_model(catboost_model_path)
23
+
24
+ # CatBoost prediction (expects structured features, here we simulate with dummy value)
25
+ def predict_with_catboost(text):
26
+ # Simulate with simple heuristic
27
+ amount = float([s for s in text.split(',') if 'Amount' in s][0].split(':')[1].strip())
28
+ prediction = catboost_model.predict([[amount]])[0]
29
+ proba = catboost_model.predict_proba([[amount]])[0][1]
30
+ return ("Fraud" if prediction == 1 else "Not Fraud"), float(proba)
31
+
32
+ # ALBERT prediction
33
+ def predict_with_albert(text):
34
+ inputs = albert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
35
+ with torch.no_grad():
36
+ outputs = albert_model(**inputs)
37
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
38
+ pred_class = torch.argmax(probs).item()
39
+ pred_prob = probs[0][1].item()
40
+ return ("Fraud" if pred_class == 1 else "Not Fraud"), float(pred_prob)
41
+
42
+ # FinBERT prediction
43
+ def predict_with_finbert(text):
44
+ inputs = finbert_tokenizer(text, return_tensors="pt", truncation=True, padding=True, max_length=64)
45
+ with torch.no_grad():
46
+ outputs = finbert_model(**inputs)
47
+ probs = torch.nn.functional.softmax(outputs.logits, dim=1)
48
+ pred_class = torch.argmax(probs).item()
49
+ pred_prob = probs[0][1].item()
50
+ return ("Fraud" if pred_class == 1 else "Not Fraud"), float(pred_prob)
51
+
52
+ # Main prediction selector
53
+ def predict(text, model_name):
54
+ if model_name == "ALBERT":
55
+ return predict_with_albert(text)
56
+ elif model_name == "FinBERT":
57
+ return predict_with_finbert(text)
58
+ elif model_name == "CatBoost":
59
+ return predict_with_catboost(text)
60
+ else:
61
+ return "Unknown Model", 0.0
62
+
63
+ # Example transactions
64
+ examples = [
65
+ "Step: 305, Type: CASH_OUT, Amount: 2321633.57, Origin Balance: 2321633.57, Dest Balance: 0.0",
66
+ "Step: 6, Type: CASH_OUT, Amount: 13704.0, Origin Balance: 13704.0, Dest Balance: 3382.84",
67
+ "Step: 285, Type: TRANSFER, Amount: 125487.45, Origin Balance: 0.0, Dest Balance: 524556.64",
68
+ "Step: 352, Type: PAYMENT, Amount: 41263.42, Origin Balance: 0.0, Dest Balance: 0.0",
69
+ "Step: 372, Type: CASH_IN, Amount: 187503.32, Origin Balance: 76827.0, Dest Balance: 0.0"
70
+ ]
71
+
72
+ # Gradio interface
73
+ gui = gr.Interface(
74
+ fn=predict,
75
+ inputs=[
76
+ gr.Textbox(label="Enter Transaction Description"),
77
+ gr.Dropdown(choices=["ALBERT", "FinBERT", "CatBoost"], label="Select Model", value="ALBERT")
78
+ ],
79
+ outputs=[
80
+ gr.Label(label="Prediction"),
81
+ gr.Number(label="Fraud Probability")
82
+ ],
83
+ examples=[[ex, "ALBERT"] for ex in examples],
84
+ title="💸 Fraud Detection Assistant (ALBERT, FinBERT, CatBoost)",
85
+ description="Analyze transaction text for fraud using your choice of model."
86
+ )
87
+
88
+ # Launch
89
+ if __name__ == "__main__":
90
+ gui.launch()