rtik007 commited on
Commit
7918a66
·
verified ·
1 Parent(s): a9b33b5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -0
app.py CHANGED
@@ -50,6 +50,45 @@ shap_values = explainer(df[columns])
50
 
51
  # Define functions for Gradio
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  def get_roc_curve():
54
  """Generates the ROC curve plot."""
55
  fpr, tpr, _ = roc_curve(true_labels, -df["Anomaly_Score"]) # Use -scores as higher scores mean normal
@@ -81,6 +120,27 @@ def get_anomaly_samples():
81
  with gr.Blocks() as demo:
82
  gr.Markdown("# Isolation Forest Anomaly Detection")
83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  with gr.Tab("Anomaly Samples"):
85
  gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Top 10 Records (Anomalies)</h3>")
86
  top_table = gr.Dataframe(label="Top 10 Records")
 
50
 
51
  # Define functions for Gradio
52
 
53
+ def get_shap_summary():
54
+ """Generates SHAP summary plot."""
55
+ plt.figure()
56
+ shap.summary_plot(shap_values, df[columns], feature_names=columns, show=False)
57
+ plt.savefig("shap_summary.png")
58
+ return "shap_summary.png"
59
+
60
+ def get_shap_waterfall(index):
61
+ """Generates SHAP waterfall plot for a specific data point."""
62
+ specific_index = int(index)
63
+ plt.figure()
64
+ shap.waterfall_plot(
65
+ shap.Explanation(
66
+ values=shap_values.values[specific_index],
67
+ base_values=shap_values.base_values[specific_index],
68
+ data=df.iloc[specific_index],
69
+ feature_names=columns
70
+ )
71
+ )
72
+ plt.savefig("shap_waterfall.png")
73
+ return "shap_waterfall.png"
74
+
75
+ def get_scatter_plot(feature1, feature2):
76
+ """Generates scatter plot for two features."""
77
+ plt.figure(figsize=(8, 6))
78
+ plt.scatter(
79
+ df[feature1],
80
+ df[feature2],
81
+ c=(df["Anomaly_Label"] == "Anomaly"),
82
+ cmap="coolwarm",
83
+ edgecolor="k",
84
+ alpha=0.7
85
+ )
86
+ plt.title(f"Isolation Forest - {feature1} vs {feature2}")
87
+ plt.xlabel(feature1)
88
+ plt.ylabel(feature2)
89
+ plt.savefig("scatter_plot.png")
90
+ return "scatter_plot.png"
91
+
92
  def get_roc_curve():
93
  """Generates the ROC curve plot."""
94
  fpr, tpr, _ = roc_curve(true_labels, -df["Anomaly_Score"]) # Use -scores as higher scores mean normal
 
120
  with gr.Blocks() as demo:
121
  gr.Markdown("# Isolation Forest Anomaly Detection")
122
 
123
+ with gr.Tab("SHAP Summary"):
124
+ gr.Markdown("### Global Explainability: SHAP Summary Plot")
125
+ shap_button = gr.Button("Generate SHAP Summary Plot")
126
+ shap_image = gr.Image()
127
+ shap_button.click(get_shap_summary, outputs=shap_image)
128
+
129
+ with gr.Tab("SHAP Waterfall"):
130
+ gr.Markdown("### Local Explainability: SHAP Waterfall Plot")
131
+ index_input = gr.Number(label="Data Point Index", value=0)
132
+ shap_waterfall_button = gr.Button("Generate SHAP Waterfall Plot")
133
+ shap_waterfall_image = gr.Image()
134
+ shap_waterfall_button.click(get_shap_waterfall, inputs=index_input, outputs=shap_waterfall_image)
135
+
136
+ with gr.Tab("Feature Scatter Plot"):
137
+ gr.Markdown("### Feature Interaction: Scatter Plot")
138
+ feature1_dropdown = gr.Dropdown(choices=columns, label="Feature 1")
139
+ feature2_dropdown = gr.Dropdown(choices=columns, label="Feature 2")
140
+ scatter_button = gr.Button("Generate Scatter Plot")
141
+ scatter_image = gr.Image()
142
+ scatter_button.click(get_scatter_plot, inputs=[feature1_dropdown, feature2_dropdown], outputs=scatter_image)
143
+
144
  with gr.Tab("Anomaly Samples"):
145
  gr.HTML("<h3 style='text-align: center; font-size: 18px; font-weight: bold;'>Top 10 Records (Anomalies)</h3>")
146
  top_table = gr.Dataframe(label="Top 10 Records")