geekyrakshit committed on
Commit
785c044
·
1 Parent(s): af688eb

add: limited eval table rendering in app

Browse files
application_pages/evaluation_app.py CHANGED
@@ -1,4 +1,6 @@
1
  import asyncio
 
 
2
  from importlib import import_module
3
 
4
  import pandas as pd
@@ -9,12 +11,11 @@ from dotenv import load_dotenv
9
  from guardrails_genie.guardrails import GuardrailManager
10
  from guardrails_genie.llm import OpenAIModel
11
  from guardrails_genie.metrics import AccuracyMetric
12
-
13
- load_dotenv()
14
- weave.init(project_name="guardrails-genie")
15
 
16
 
17
  def initialize_session_state():
 
18
  if "uploaded_file" not in st.session_state:
19
  st.session_state.uploaded_file = None
20
  if "dataset_name" not in st.session_state:
@@ -35,6 +36,18 @@ def initialize_session_state():
35
  st.session_state.evaluation_summary = None
36
  if "guardrail_manager" not in st.session_state:
37
  st.session_state.guardrail_manager = None
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
 
40
  def initialize_guardrail():
@@ -107,6 +120,8 @@ if st.session_state.dataset_previewed:
107
 
108
  if st.session_state.guardrail_names != []:
109
  initialize_guardrail()
 
 
110
  if st.session_state.guardrail_manager is not None:
111
  if st.sidebar.button("Start Evaluation"):
112
  st.session_state.start_evaluation = True
@@ -119,10 +134,54 @@ if st.session_state.dataset_previewed:
119
  with st.expander("Evaluation Results", expanded=True):
120
  evaluation_summary, call = asyncio.run(
121
  evaluation.evaluate.call(
122
- evaluation, st.session_state.guardrail_manager
 
 
 
 
 
123
  )
124
  )
125
- st.markdown(f"[Explore evaluation in Weave]({call.ui_url})")
126
- st.write(evaluation_summary)
127
- st.session_state.evaluation_summary = evaluation_summary
128
- st.session_state.start_evaluation = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import asyncio
2
+ import os
3
+ import time
4
  from importlib import import_module
5
 
6
  import pandas as pd
 
11
  from guardrails_genie.guardrails import GuardrailManager
12
  from guardrails_genie.llm import OpenAIModel
13
  from guardrails_genie.metrics import AccuracyMetric
14
+ from guardrails_genie.utils import EvaluationCallManager
 
 
15
 
16
 
17
  def initialize_session_state():
18
+ load_dotenv()
19
  if "uploaded_file" not in st.session_state:
20
  st.session_state.uploaded_file = None
21
  if "dataset_name" not in st.session_state:
 
36
  st.session_state.evaluation_summary = None
37
  if "guardrail_manager" not in st.session_state:
38
  st.session_state.guardrail_manager = None
39
+ if "evaluation_name" not in st.session_state:
40
+ st.session_state.evaluation_name = ""
41
+ if "show_result_table" not in st.session_state:
42
+ st.session_state.show_result_table = False
43
+ if "weave_client" not in st.session_state:
44
+ st.session_state.weave_client = weave.init(
45
+ project_name=os.getenv("WEAVE_PROJECT")
46
+ )
47
+ if "evaluation_call_manager" not in st.session_state:
48
+ st.session_state.evaluation_call_manager = None
49
+ if "call_id" not in st.session_state:
50
+ st.session_state.call_id = None
51
 
52
 
53
  def initialize_guardrail():
 
120
 
121
  if st.session_state.guardrail_names != []:
122
  initialize_guardrail()
123
+ evaluation_name = st.sidebar.text_input("Evaluation name", value="")
124
+ st.session_state.evaluation_name = evaluation_name
125
  if st.session_state.guardrail_manager is not None:
126
  if st.sidebar.button("Start Evaluation"):
127
  st.session_state.start_evaluation = True
 
134
  with st.expander("Evaluation Results", expanded=True):
135
  evaluation_summary, call = asyncio.run(
136
  evaluation.evaluate.call(
137
+ evaluation,
138
+ st.session_state.guardrail_manager,
139
+ __weave={
140
+ "display_name": "Evaluation.evaluate:"
141
+ + st.session_state.evaluation_name
142
+ },
143
  )
144
  )
145
+ x_axis = list(evaluation_summary["AccuracyMetric"].keys())
146
+ y_axis = [
147
+ evaluation_summary["AccuracyMetric"][x_axis_item]
148
+ for x_axis_item in x_axis
149
+ ]
150
+ st.bar_chart(
151
+ pd.DataFrame({"Metric": x_axis, "Score": y_axis}),
152
+ x="Metric",
153
+ y="Score",
154
+ )
155
+ st.session_state.evaluation_summary = evaluation_summary
156
+ st.session_state.call_id = call.id
157
+ st.session_state.start_evaluation = False
158
+
159
+ if not st.session_state.start_evaluation:
160
+ time.sleep(5)
161
+ st.session_state.evaluation_call_manager = (
162
+ EvaluationCallManager(
163
+ entity="geekyrakshit",
164
+ project="guardrails-genie",
165
+ call_id=st.session_state.call_id,
166
+ )
167
+ )
168
+ for guardrail_name in st.session_state.guardrail_names:
169
+ st.session_state.evaluation_call_manager.call_list.append(
170
+ {
171
+ "guardrail_name": guardrail_name,
172
+ "calls": st.session_state.evaluation_call_manager.collect_guardrail_guard_calls_from_eval(
173
+ call=call
174
+ ),
175
+ }
176
+ )
177
+ st.dataframe(
178
+ st.session_state.evaluation_call_manager.render_calls_to_streamlit()
179
+ )
180
+ if st.session_state.evaluation_call_manager.show_warning_in_app:
181
+ st.warning(
182
+ f"Only {st.session_state.evaluation_call_manager.max_count} calls can be shown in the app."
183
+ )
184
+ st.markdown(
185
+ f"Explore the entire evaluation trace table in [Weave]({call.ui_url})"
186
+ )
187
+ st.session_state.evaluation_call_manager = None
guardrails_genie/utils.py CHANGED
@@ -1,7 +1,9 @@
1
  import os
2
 
 
3
  import pymupdf4llm
4
  import weave
 
5
  from firerequests import FireRequests
6
 
7
 
@@ -11,3 +13,44 @@ def get_markdown_from_pdf_url(url: str) -> str:
11
  markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
12
  os.remove("temp.pdf")
13
  return markdown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import os
2
 
3
+ import pandas as pd
4
  import pymupdf4llm
5
  import weave
6
+ import weave.trace
7
  from firerequests import FireRequests
8
 
9
 
 
13
  markdown = pymupdf4llm.to_markdown("temp.pdf", show_progress=False)
14
  os.remove("temp.pdf")
15
  return markdown
16
+
17
+
18
class EvaluationCallManager:
    """Collect guardrail call records from a Weave evaluation trace and
    tabulate them for display.

    Attributes:
        base_call: The call object returned by
            ``weave.init(f"{entity}/{project}").get_call(call_id=call_id)``.
        max_count: Upper bound on the number of records gathered by one
            ``collect_guardrail_guard_calls_from_eval`` invocation.
        show_warning_in_app: Set to ``True`` once a collection had to stop
            because ``max_count`` was reached while more calls remained.
        call_list: List of ``{"guardrail_name": str, "calls": list}`` entries
            appended by the caller and consumed by
            ``render_calls_to_streamlit``.
    """

    def __init__(self, entity: str, project: str, call_id: str, max_count: int = 10):
        """Initialise Weave for ``entity/project`` and look up ``call_id``.

        Args:
            entity: Weave entity name, used as the first half of the
                project path.
            project: Weave project name, used as the second half of the
                project path.
            call_id: Identifier of the call to fetch as ``base_call``.
            max_count: Maximum number of records to collect per evaluation.
        """
        self.base_call = weave.init(f"{entity}/{project}").get_call(call_id=call_id)
        self.max_count = max_count
        self.show_warning_in_app = False
        self.call_list = []

    def collect_guardrail_guard_calls_from_eval(self, call) -> list:
        """Gather prompt/output records from the children of ``call``.

        Children are visited in order until one whose ``_op_name`` contains
        ``"Evaluation.summarize"`` is seen, or until ``max_count`` records
        have been collected and another child would have to be dropped.
        Only in the latter case is ``show_warning_in_app`` set, so the flag
        is raised only when the result really is truncated.

        Args:
            call: An object exposing ``children()``; each child must expose
                ``_op_name`` and ``children()``.

        Returns:
            A list of ``{"input_prompt": str, "outputs": dict}`` records.

        Raises:
            IndexError: If a visited child does not have the expected
                three levels of nested children.
        """
        guard_calls = []
        for eval_predict_call in call.children():
            if "Evaluation.summarize" in eval_predict_call._op_name:
                break
            if len(guard_calls) >= self.max_count:
                # At least one more call exists past the limit, so the
                # table shown in the app is genuinely incomplete.
                self.show_warning_in_app = True
                break
            # Descend through the first child three levels deep.
            # NOTE(review): presumably predict-and-score -> predict ->
            # guardrail guard call; confirm against the Weave trace layout.
            required_call = eval_predict_call.children()[0].children()[0].children()[0]
            guard_calls.append(
                {
                    "input_prompt": str(required_call.inputs["prompt"]),
                    "outputs": dict(required_call.output),
                }
            )
        return guard_calls

    def render_calls_to_streamlit(self) -> pd.DataFrame:
        """Build a table with one row per collected call.

        The ``input_prompt`` column is taken from the first entry of
        ``call_list``; every entry then contributes a
        ``<guardrail_name>.safe`` and a ``<guardrail_name>.summary`` column.

        Returns:
            The assembled ``pandas.DataFrame``. When ``call_list`` is empty,
            a frame with only an empty ``input_prompt`` column is returned
            instead of raising ``IndexError``.
        """
        if not self.call_list:
            return pd.DataFrame({"input_prompt": []})

        dataframe = {
            "input_prompt": [
                call["input_prompt"] for call in self.call_list[0]["calls"]
            ]
        }
        for guardrail_call in self.call_list:
            name = guardrail_call["guardrail_name"]
            calls = guardrail_call["calls"]
            dataframe[name + ".safe"] = [call["outputs"]["safe"] for call in calls]
            dataframe[name + ".summary"] = [
                call["outputs"]["summary"] for call in calls
            ]
        return pd.DataFrame(dataframe)