XiangJinYu committed on
Commit 0a59e16 · verified · 1 Parent(s): ae33945

Upload app.py

Files changed (1)
app.py +283 -0
app.py ADDED
@@ -0,0 +1,283 @@
+ import asyncio
+ from pathlib import Path
+ from typing import Dict
+
+ import streamlit as st
+ import yaml
+ from loguru import logger as _logger
+
+ from metagpt.const import METAGPT_ROOT
+ from metagpt.ext.spo.components.optimizer import PromptOptimizer
+ from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType
+
+
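+ # A settings template is a YAML file with the four top-level keys that
+ # save_yaml_template() below writes out. An illustrative (not shipped) example:
+ #
+ #   prompt: "You are a helpful assistant."
+ #   requirements: "Keep answers concise."
+ #   count: null
+ #   qa:
+ #     - question: "What does SPO stand for?"
+ #       answer: "Self-Supervised Prompt Optimization."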
+ def load_yaml_template(template_path: Path) -> Dict:
+     if template_path.exists():
+         with open(template_path, "r", encoding="utf-8") as f:
+             return yaml.safe_load(f)
+     return {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]}
+
+
+ def save_yaml_template(template_path: Path, data: Dict) -> None:
+     template_format = {
+         "prompt": str(data.get("prompt", "")),
+         "requirements": str(data.get("requirements", "")),
+         "count": data.get("count"),
+         "qa": [
+             {"question": str(qa.get("question", "")).strip(), "answer": str(qa.get("answer", "")).strip()}
+             for qa in data.get("qa", [])
+         ],
+     }
+
+     template_path.parent.mkdir(parents=True, exist_ok=True)
+
+     with open(template_path, "w", encoding="utf-8") as f:
+         yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
+
+
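+ # display_optimization_results() renders the list returned by the optimizer's
+ # data_utils.load_results(); judging from the key accesses below, each entry
+ # looks roughly like:
+ #   {"round": 1, "succeed": True, "prompt": "...", "tokens": 123,
+ #    "answers": [{"question": "...", "answer": "..."}]}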
+ def display_optimization_results(result_data):
+     for result in result_data:
+         round_num = result["round"]
+         success = result["succeed"]
+         prompt = result["prompt"]
+
+         with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
+             st.markdown("**Prompt:**")
+             st.code(prompt, language="text")
+             st.markdown("<br>", unsafe_allow_html=True)
+
+             col1, col2 = st.columns(2)
+             with col1:
+                 st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
+             with col2:
+                 st.markdown(f"**Tokens:** {result['tokens']}")
+
+             st.markdown("**Answers:**")
+             for idx, answer in enumerate(result["answers"]):
+                 st.markdown(f"**Question {idx + 1}:**")
+                 st.text(answer["question"])
+                 st.markdown("**Answer:**")
+                 st.text(answer["answer"])
+                 st.markdown("---")
+
+     # Summary
+     success_count = sum(1 for r in result_data if r["succeed"])
+     total_rounds = len(result_data)
+
+     st.markdown("### Summary")
+     col1, col2 = st.columns(2)
+     with col1:
+         st.metric("Total Rounds", total_rounds)
+     with col2:
+         st.metric("Successful Rounds", success_count)
+
+
+ def main():
+     if "optimization_results" not in st.session_state:
+         st.session_state.optimization_results = []
+
+     st.title("SPO | Self-Supervised Prompt Optimization 🤖")
+
+     # Sidebar for configurations
+     with st.sidebar:
+         st.header("Configuration")
+
+         # Template Selection/Creation
+         settings_path = Path("metagpt/ext/spo/settings")
+         existing_templates = [f.stem for f in settings_path.glob("*.yaml")]
+
+         template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])
+
+         if template_mode == "Use Existing":
+             template_name = st.selectbox("Select Template", existing_templates)
+         else:
+             template_name = st.text_input("New Template Name")
+             if template_name and template_name.endswith(".yaml"):
+                 # Strip the extension; ".yaml" is re-appended wherever the path is built.
+                 template_name = template_name[: -len(".yaml")]
+
+         # LLM Settings
+         st.subheader("LLM Settings")
+         opt_model = st.selectbox(
+             "Optimization Model", ["claude-3-5-sonnet-20240620", "gpt-4o", "gpt-4o-mini", "deepseek-chat"], index=0
+         )
+         opt_temp = st.slider("Optimization Temperature", 0.0, 1.0, 0.7)
+
+         eval_model = st.selectbox(
+             "Evaluation Model", ["gpt-4o-mini", "claude-3-5-sonnet-20240620", "gpt-4o", "deepseek-chat"], index=0
+         )
+         eval_temp = st.slider("Evaluation Temperature", 0.0, 1.0, 0.3)
+
+         exec_model = st.selectbox(
+             "Execution Model", ["gpt-4o-mini", "claude-3-5-sonnet-20240620", "gpt-4o", "deepseek-chat"], index=0
+         )
+         exec_temp = st.slider("Execution Temperature", 0.0, 1.0, 0.0)
+
+         # Optimizer Settings
+         st.subheader("Optimizer Settings")
+         initial_round = st.number_input("Initial Round", 1, 100, 1)
+         max_rounds = st.number_input("Maximum Rounds", 1, 100, 10)
+
+     # Main content area
+     st.header("Template Configuration")
+
+     if template_name:
+         template_path = settings_path / f"{template_name}.yaml"
+         template_data = load_yaml_template(template_path)
+
+         if "current_template" not in st.session_state or st.session_state.current_template != template_name:
+             st.session_state.current_template = template_name
+             st.session_state.qas = template_data.get("qa", [])
+
+         # Edit template sections
+         prompt = st.text_area("Prompt", template_data.get("prompt", ""), height=100)
+         requirements = st.text_area("Requirements", template_data.get("requirements", ""), height=100)
+
+         # qa section
+         st.subheader("Q&A Examples")
+
+         # Add new qa button
+         if st.button("Add New Q&A"):
+             st.session_state.qas.append({"question": "", "answer": ""})
+
+         # Edit qas
+         new_qas = []
+         for i in range(len(st.session_state.qas)):
+             st.markdown(f"**QA #{i + 1}**")
+             col1, col2, col3 = st.columns([45, 45, 10])
+
+             with col1:
+                 question = st.text_area(
+                     f"Question {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
+                 )
+             with col2:
+                 answer = st.text_area(
+                     f"Answer {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
+                 )
+             with col3:
+                 if st.button("🗑️", key=f"delete_{i}"):
+                     st.session_state.qas.pop(i)
+                     st.rerun()
+
+             new_qas.append({"question": question, "answer": answer})
+
+         # Save template button
+         if st.button("Save Template"):
+             new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
+
+             save_yaml_template(template_path, new_template_data)
+
+             st.session_state.qas = new_qas
+             st.success(f"Template saved to {template_path}")
+
+         st.subheader("Current Template Preview")
+         preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
+         st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
+
+         st.subheader("Optimization Logs")
+         log_container = st.empty()
+
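+         # Route loguru records into the page: each record is appended to
+         # session state and the placeholder above is re-rendered.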
+         class StreamlitSink:
+             def write(self, message):
+                 current_logs = st.session_state.get("logs", [])
+                 current_logs.append(message.strip())
+                 st.session_state.logs = current_logs
+
+                 log_container.code("\n".join(current_logs), language="plaintext")
+
+         streamlit_sink = StreamlitSink()
+         _logger.remove()
+
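+         # Only surface records coming from optimizer modules in the UI;
+         # the file sink added below still captures everything at DEBUG.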
+         def prompt_optimizer_filter(record):
+             return "optimizer" in record["name"].lower()
+
+         _logger.add(
+             streamlit_sink.write,
+             format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
+             filter=prompt_optimizer_filter,
+         )
+         _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
+
+         # Start optimization button
+         if st.button("Start Optimization"):
+             try:
+                 # Initialize LLM
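+                 # Three roles, one client: the optimize model rewrites prompts,
+                 # the evaluate model compares outputs, and the execute model
+                 # answers the template's questions with each candidate prompt.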
+                 SPO_LLM.initialize(
+                     optimize_kwargs={"model": opt_model, "temperature": opt_temp},
+                     evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
+                     execute_kwargs={"model": exec_model, "temperature": exec_temp},
+                 )
+
+                 # Create optimizer instance
+                 optimizer = PromptOptimizer(
+                     optimized_path="workspace",
+                     initial_round=initial_round,
+                     max_rounds=max_rounds,
+                     template=f"{template_name}.yaml",
+                     name=template_name,
+                 )
+
+                 # Run optimization with a progress spinner
+                 with st.spinner("Optimizing prompts..."):
+                     optimizer.optimize()
+
+                 st.success("Optimization completed!")
+
+                 st.header("Optimization Results")
+
+                 prompt_path = optimizer.root_path / "prompts"
+                 result_data = optimizer.data_utils.load_results(prompt_path)
+
+                 st.session_state.optimization_results = result_data
+
+             except Exception as e:
+                 st.error(f"An error occurred: {str(e)}")
+                 _logger.error(f"Error during optimization: {str(e)}")
+
+         if st.session_state.optimization_results:
+             st.header("Optimization Results")
+             display_optimization_results(st.session_state.optimization_results)
+
+         st.markdown("---")
+         st.subheader("Test Optimized Prompt")
+         col1, col2 = st.columns(2)
+
+         with col1:
+             test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")
+
+         with col2:
+             test_question = st.text_area("Your Question", value="", height=200, key="test_question")
+
+         if st.button("Test Prompt"):
+             if test_prompt and test_question:
+                 try:
+                     with st.spinner("Generating response..."):
+                         SPO_LLM.initialize(
+                             optimize_kwargs={"model": opt_model, "temperature": opt_temp},
+                             evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
+                             execute_kwargs={"model": exec_model, "temperature": exec_temp},
+                         )
+
+                         llm = SPO_LLM.get_instance()
+                         messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]
+
+                         async def get_response():
+                             return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)
+
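+                         # Streamlit runs this script in a plain thread that
+                         # typically has no running event loop, so create a
+                         # throwaway loop for this one async call and close it.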
+                         loop = asyncio.new_event_loop()
+                         asyncio.set_event_loop(loop)
+                         try:
+                             response = loop.run_until_complete(get_response())
+                         finally:
+                             loop.close()
+
+                         st.subheader("Response:")
+                         st.markdown(response)
+
+                 except Exception as e:
+                     st.error(f"Error generating response: {str(e)}")
+             else:
+                 st.warning("Please enter both prompt and question.")
+
+
+ if __name__ == "__main__":
+     main()
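Note: the app appears to assume it is launched from the MetaGPT repository root (e.g. streamlit run app.py) so that the relative metagpt/ext/spo/settings path resolves.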