geekyrakshit commited on
Commit
dfbca8a
1 Parent(s): a202ba5

add: visualization of ROC curve + score distribution

Browse files
Files changed (1) hide show
  1. guardrails_genie/train/llama_guard.py +77 -17
guardrails_genie/train/llama_guard.py CHANGED
@@ -1,4 +1,4 @@
1
- import matplotlib.pyplot as plt
2
  import streamlit as st
3
  import torch
4
  import torch.nn.functional as F
@@ -98,28 +98,87 @@ class LlamaGuardFineTuner:
98
  return scores
99
 
100
  def visualize_roc_curve(self, test_scores: list[float]):
101
- plt.figure(figsize=(8, 6))
102
  test_labels = [int(elt) for elt in self.test_dataset["label"]]
103
  fpr, tpr, _ = roc_curve(test_labels, test_scores)
104
  roc_auc = roc_auc_score(test_labels, test_scores)
105
- plt.plot(
106
- fpr,
107
- tpr,
108
- color="darkorange",
109
- lw=2,
110
- label=f"ROC curve (area = {roc_auc:.3f})",
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
111
  )
112
- plt.plot([0, 1], [0, 1], color="navy", lw=2, linestyle="--")
113
- plt.xlim([0.0, 1.0])
114
- plt.ylim([0.0, 1.05])
115
- plt.xlabel("False Positive Rate")
116
- plt.ylabel("True Positive Rate")
117
- plt.title("Receiver Operating Characteristic")
118
- plt.legend(loc="lower right")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
119
  if self.streamlit_mode:
120
- st.pyplot(plt)
121
  else:
122
- plt.show()
123
 
124
  def evaluate_model(
125
  self,
@@ -138,4 +197,5 @@ class LlamaGuardFineTuner:
138
  max_length=max_length,
139
  )
140
  self.visualize_roc_curve(test_scores)
 
141
  return test_scores
 
1
+ import plotly.graph_objects as go
2
  import streamlit as st
3
  import torch
4
  import torch.nn.functional as F
 
98
  return scores
99
 
100
  def visualize_roc_curve(self, test_scores: list[float]):
 
101
  test_labels = [int(elt) for elt in self.test_dataset["label"]]
102
  fpr, tpr, _ = roc_curve(test_labels, test_scores)
103
  roc_auc = roc_auc_score(test_labels, test_scores)
104
+
105
+ fig = go.Figure()
106
+ fig.add_trace(
107
+ go.Scatter(
108
+ x=fpr,
109
+ y=tpr,
110
+ mode="lines",
111
+ name=f"ROC curve (area = {roc_auc:.3f})",
112
+ line=dict(color="darkorange", width=2),
113
+ )
114
+ )
115
+ fig.add_trace(
116
+ go.Scatter(
117
+ x=[0, 1],
118
+ y=[0, 1],
119
+ mode="lines",
120
+ name="Random Guess",
121
+ line=dict(color="navy", width=2, dash="dash"),
122
+ )
123
+ )
124
+
125
+ fig.update_layout(
126
+ title="Receiver Operating Characteristic",
127
+ xaxis_title="False Positive Rate",
128
+ yaxis_title="True Positive Rate",
129
+ xaxis=dict(range=[0.0, 1.0]),
130
+ yaxis=dict(range=[0.0, 1.05]),
131
+ legend=dict(x=0.8, y=0.2),
132
  )
133
+
134
+ if self.streamlit_mode:
135
+ st.plotly_chart(fig)
136
+ else:
137
+ fig.show()
138
+
139
+ def visualize_score_distribution(self, scores: list[float]):
140
+ test_labels = [int(elt) for elt in self.test_dataset["label"]]
141
+ positive_scores = [scores[i] for i in range(500) if test_labels[i] == 1]
142
+ negative_scores = [scores[i] for i in range(500) if test_labels[i] == 0]
143
+
144
+ fig = go.Figure()
145
+
146
+ # Plotting positive scores
147
+ fig.add_trace(
148
+ go.Histogram(
149
+ x=positive_scores,
150
+ histnorm="probability density",
151
+ name="Positive",
152
+ marker_color="darkblue",
153
+ opacity=0.75,
154
+ )
155
+ )
156
+
157
+ # Plotting negative scores
158
+ fig.add_trace(
159
+ go.Histogram(
160
+ x=negative_scores,
161
+ histnorm="probability density",
162
+ name="Negative",
163
+ marker_color="darkred",
164
+ opacity=0.75,
165
+ )
166
+ )
167
+
168
+ # Updating layout
169
+ fig.update_layout(
170
+ title="Score Distribution for Positive and Negative Examples",
171
+ xaxis_title="Score",
172
+ yaxis_title="Density",
173
+ barmode="overlay",
174
+ legend_title="Scores",
175
+ )
176
+
177
+ # Display the plot
178
  if self.streamlit_mode:
179
+ st.plotly_chart(fig)
180
  else:
181
+ fig.show()
182
 
183
  def evaluate_model(
184
  self,
 
197
  max_length=max_length,
198
  )
199
  self.visualize_roc_curve(test_scores)
200
+ self.visualize_score_distribution(test_scores)
201
  return test_scores