Ujeshhh commited on
Commit
cf86c3a
·
verified ·
1 Parent(s): 155de20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +44 -37
app.py CHANGED
@@ -1,58 +1,65 @@
 
1
  import pandas as pd
2
  import joblib
3
- import gradio as gr
4
- import seaborn as sns
5
  import matplotlib.pyplot as plt
 
6
 
7
- # Load the trained model
8
  model = joblib.load("anomaly_detector_rf_model.pkl")
9
 
10
- # Define feature order
11
- feature_order = ['hour', 'day_of_week', 'is_weekend', 'amount', 'merchant_avg_amount',
12
- 'amount_zscore', 'log_amount', 'type_atm_withdrawal', 'type_credit',
13
- 'type_debit', 'merchant_encoded']
 
 
14
 
15
- def detect_anomalies(data):
16
- df = pd.DataFrame(data)
17
- df = df[feature_order] # Ensure correct feature order
18
- df['is_anomalous'] = model.predict(df)
 
 
 
 
 
 
 
 
 
 
19
 
20
- # Filter anomalies and display relevant details
21
- anomalies = df[df['is_anomalous'] == 1][['transaction_id', 'merchant', 'location', 'amount']]
22
- return anomalies
 
23
 
24
- # Function to generate plots
25
- def generate_plots(df):
26
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
27
-
28
- sns.countplot(data=df, x='is_anomalous', palette='Set2', ax=axes[0, 0])
29
- axes[0, 0].set_title("Anomaly Distribution")
30
-
31
- sns.countplot(data=df, y='merchant', order=df['merchant'].value_counts().index, palette='viridis', ax=axes[0, 1])
32
- axes[0, 1].set_title("Transactions by Merchant")
33
-
34
- sns.histplot(df['amount'], bins=30, kde=True, color='blue', ax=axes[1, 0])
35
- axes[1, 0].set_title("Transaction Amount Distribution")
36
-
37
- sns.scatterplot(data=df, x='amount', y='merchant_avg_amount', hue='is_anomalous', palette='coolwarm', ax=axes[1, 1])
38
- axes[1, 1].set_title("Amount vs. Merchant Average Amount")
39
-
40
  plt.tight_layout()
41
  return fig
42
 
43
  # Gradio Interface
44
- def app_interface(file):
45
- df = pd.read_csv(file.name)
46
  anomalies = detect_anomalies(df)
47
- plot = generate_plots(df)
48
- return anomalies, plot
 
49
 
50
  interface = gr.Interface(
51
  fn=app_interface,
52
- inputs=[gr.File(label="Upload Transaction Data (CSV)")],
53
- outputs=[gr.Dataframe(label="Detected Anomalies"), gr.Plot(label="Transaction Analysis Charts")],
54
- title="Financial Anomaly Detection",
55
- description="Upload a transaction dataset to detect financial anomalies and visualize transaction patterns."
56
  )
57
 
 
58
  interface.launch(share=True)
 
1
+ import gradio as gr
2
  import pandas as pd
3
  import joblib
 
 
4
  import matplotlib.pyplot as plt
5
+ import seaborn as sns
6
 
7
+ # Load trained model
8
  model = joblib.load("anomaly_detector_rf_model.pkl")
9
 
10
+ # Define feature columns used during training
11
+ feature_cols = [
12
+ "hour", "day_of_week", "is_weekend", "merchant_avg_amount",
13
+ "amount_zscore", "log_amount", "type_atm_withdrawal", "type_credit",
14
+ "type_debit", "merchant_encoded"
15
+ ]
16
 
17
+ # Function to detect anomalies
18
+ def detect_anomalies(df):
19
+ # Ensure only the required features are used & in correct order
20
+ missing_features = [col for col in feature_cols if col not in df.columns]
21
+ extra_features = [col for col in df.columns if col not in feature_cols]
22
+
23
+ if missing_features:
24
+ return f"Missing features: {missing_features}"
25
+
26
+ if extra_features:
27
+ df = df[feature_cols] # Select only relevant columns
28
+
29
+ # Make predictions
30
+ df["is_anomalous"] = model.predict(df)
31
 
32
+ # Filter anomalies
33
+ anomalies = df[df["is_anomalous"] == 1]
34
+
35
+ return anomalies[["transaction_id", "merchant", "location", "amount"]]
36
 
37
+ # Function to visualize anomalies
38
+ def plot_charts(df):
39
  fig, axes = plt.subplots(2, 2, figsize=(12, 10))
40
+ sns.histplot(df["amount"], bins=30, kde=True, ax=axes[0, 0])
41
+ sns.boxplot(x=df["amount"], ax=axes[0, 1])
42
+ sns.countplot(x=df["day_of_week"], ax=axes[1, 0])
43
+ sns.barplot(x=df["merchant"], y=df["amount"], ax=axes[1, 1])
44
+
 
 
 
 
 
 
 
 
45
  plt.tight_layout()
46
  return fig
47
 
48
  # Gradio Interface
49
+ def app_interface(csv_file):
50
+ df = pd.read_csv(csv_file)
51
  anomalies = detect_anomalies(df)
52
+ fig = plot_charts(df)
53
+
54
+ return anomalies, fig
55
 
56
  interface = gr.Interface(
57
  fn=app_interface,
58
+ inputs="file",
59
+ outputs=["dataframe", "plot"],
60
+ title="Financial Anomaly Detector",
61
+ description="Upload a transaction CSV file to detect fraudulent transactions."
62
  )
63
 
64
+ # Launch the Gradio app with public access
65
  interface.launch(share=True)