Upload app.py
app.py
ADDED
@@ -0,0 +1,283 @@
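# Streamlit front end for MetaGPT's SPO (Self-Supervised Prompt Optimization)
# extension: edit a YAML task template, run the prompt optimizer, inspect
# per-round results, and try out the optimized prompt.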
import asyncio
from pathlib import Path
from typing import Dict

import streamlit as st
import yaml
from loguru import logger as _logger

from metagpt.const import METAGPT_ROOT
from metagpt.ext.spo.components.optimizer import PromptOptimizer
from metagpt.ext.spo.utils.llm_client import SPO_LLM, RequestType


def load_yaml_template(template_path: Path) -> Dict:
    if template_path.exists():
        with open(template_path, "r", encoding="utf-8") as f:
            return yaml.safe_load(f)
    return {"prompt": "", "requirements": "", "count": None, "qa": [{"question": "", "answer": ""}]}


def save_yaml_template(template_path: Path, data: Dict) -> None:
    template_format = {
        "prompt": str(data.get("prompt", "")),
        "requirements": str(data.get("requirements", "")),
        "count": data.get("count"),
        "qa": [
            {"question": str(qa.get("question", "")).strip(), "answer": str(qa.get("answer", "")).strip()}
            for qa in data.get("qa", [])
        ],
    }

    template_path.parent.mkdir(parents=True, exist_ok=True)

    with open(template_path, "w", encoding="utf-8") as f:
        yaml.dump(template_format, f, allow_unicode=True, sort_keys=False, default_flow_style=False, indent=2)
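# For reference, a template saved by save_yaml_template might look like this
# on disk (hypothetical values):
#
#   prompt: You are a helpful assistant...
#   requirements: Keep answers concise.
#   count: null
#   qa:
#     - question: What is 2 + 2?
#       answer: '4'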

def display_optimization_results(result_data):
    for result in result_data:
        round_num = result["round"]
        success = result["succeed"]
        prompt = result["prompt"]

        with st.expander(f"Round {round_num} {':white_check_mark:' if success else ':x:'}"):
            st.markdown("**Prompt:**")
            st.code(prompt, language="text")
            st.markdown("<br>", unsafe_allow_html=True)

            col1, col2 = st.columns(2)
            with col1:
                st.markdown(f"**Status:** {'Success ✅ ' if success else 'Failed ❌ '}")
            with col2:
                st.markdown(f"**Tokens:** {result['tokens']}")

            st.markdown("**Answers:**")
            for idx, answer in enumerate(result["answers"]):
                st.markdown(f"**Question {idx + 1}:**")
                st.text(answer["question"])
                st.markdown("**Answer:**")
                st.text(answer["answer"])
                st.markdown("---")

    # Summary
    success_count = sum(1 for r in result_data if r["succeed"])
    total_rounds = len(result_data)

    st.markdown("### Summary")
    col1, col2 = st.columns(2)
    with col1:
        st.metric("Total Rounds", total_rounds)
    with col2:
        st.metric("Successful Rounds", success_count)

def main():
    if "optimization_results" not in st.session_state:
        st.session_state.optimization_results = []

    st.title("SPO | Self-Supervised Prompt Optimization 🤖")

    # Sidebar for configurations
    with st.sidebar:
        st.header("Configuration")

        # Template Selection/Creation
        settings_path = Path("metagpt/ext/spo/settings")
        existing_templates = [f.stem for f in settings_path.glob("*.yaml")]

        template_mode = st.radio("Template Mode", ["Use Existing", "Create New"])

        if template_mode == "Use Existing":
            template_name = st.selectbox("Select Template", existing_templates)
        else:
            template_name = st.text_input("New Template Name")
            # Keep the bare name; ".yaml" is appended wherever the path is built.
            if template_name and template_name.endswith(".yaml"):
                template_name = template_name[: -len(".yaml")]
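        # SPO drives three separately configured models: one proposes improved
        # prompts (optimization), one judges candidate outputs (evaluation),
        # and one answers questions with the candidate prompt (execution).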
        # LLM Settings
        st.subheader("LLM Settings")
        opt_model = st.selectbox(
            "Optimization Model", ["claude-3-5-sonnet-20240620", "gpt-4o", "gpt-4o-mini", "deepseek-chat"], index=0
        )
        opt_temp = st.slider("Optimization Temperature", 0.0, 1.0, 0.7)

        eval_model = st.selectbox(
            "Evaluation Model", ["gpt-4o-mini", "claude-3-5-sonnet-20240620", "gpt-4o", "deepseek-chat"], index=0
        )
        eval_temp = st.slider("Evaluation Temperature", 0.0, 1.0, 0.3)

        exec_model = st.selectbox(
            "Execution Model", ["gpt-4o-mini", "claude-3-5-sonnet-20240620", "gpt-4o", "deepseek-chat"], index=0
        )
        exec_temp = st.slider("Execution Temperature", 0.0, 1.0, 0.0)

        # Optimizer Settings
        st.subheader("Optimizer Settings")
        initial_round = st.number_input("Initial Round", 1, 100, 1)
        max_rounds = st.number_input("Maximum Rounds", 1, 100, 10)
|
120 |
+
# Main content area
|
121 |
+
st.header("Template Configuration")
|
122 |
+
|
123 |
+
if template_name:
|
124 |
+
template_path = settings_path / f"{template_name}.yaml"
|
125 |
+
template_data = load_yaml_template(template_path)
|
126 |
+
|
127 |
+
if "current_template" not in st.session_state or st.session_state.current_template != template_name:
|
128 |
+
st.session_state.current_template = template_name
|
129 |
+
st.session_state.qas = template_data.get("qa", [])
|
130 |
+
|
131 |
+
# Edit template sections
|
132 |
+
prompt = st.text_area("Prompt", template_data.get("prompt", ""), height=100)
|
133 |
+
requirements = st.text_area("Requirements", template_data.get("requirements", ""), height=100)
|
134 |
+
|
135 |
+
# qa section
|
136 |
+
st.subheader("Q&A Examples")
|
137 |
+
|
138 |
+
# Add new qa button
|
139 |
+
if st.button("Add New Q&A"):
|
140 |
+
st.session_state.qas.append({"question": "", "answer": ""})
|
141 |
+
|
142 |
+
# Edit qas
|
143 |
+
new_qas = []
|
144 |
+
for i in range(len(st.session_state.qas)):
|
145 |
+
st.markdown(f"**QA #{i + 1}**")
|
146 |
+
col1, col2, col3 = st.columns([45, 45, 10])
|
147 |
+
|
148 |
+
with col1:
|
149 |
+
question = st.text_area(
|
150 |
+
f"Question {i + 1}", st.session_state.qas[i].get("question", ""), key=f"q_{i}", height=100
|
151 |
+
)
|
152 |
+
with col2:
|
153 |
+
answer = st.text_area(
|
154 |
+
f"Answer {i + 1}", st.session_state.qas[i].get("answer", ""), key=f"a_{i}", height=100
|
155 |
+
)
|
156 |
+
with col3:
|
157 |
+
if st.button("🗑️", key=f"delete_{i}"):
|
158 |
+
st.session_state.qas.pop(i)
|
159 |
+
st.rerun()
|
160 |
+
|
161 |
+
new_qas.append({"question": question, "answer": answer})
|
162 |
+
|
163 |
+
# Save template button
|
164 |
+
if st.button("Save Template"):
|
165 |
+
new_template_data = {"prompt": prompt, "requirements": requirements, "count": None, "qa": new_qas}
|
166 |
+
|
167 |
+
save_yaml_template(template_path, new_template_data)
|
168 |
+
|
169 |
+
st.session_state.qas = new_qas
|
170 |
+
st.success(f"Template saved to {template_path}")
|
171 |
+
|
172 |
+
st.subheader("Current Template Preview")
|
173 |
+
preview_data = {"qa": new_qas, "requirements": requirements, "prompt": prompt}
|
174 |
+
st.code(yaml.dump(preview_data, allow_unicode=True), language="yaml")
|
175 |
+
|
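        # loguru accepts any callable as a sink, so the class below simply
        # appends each record to session state and re-renders the log pane;
        # _logger.remove() drops the default stderr handler first so records
        # are not emitted twice.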
        st.subheader("Optimization Logs")
        log_container = st.empty()

        class StreamlitSink:
            def write(self, message):
                current_logs = st.session_state.get("logs", [])
                current_logs.append(message.strip())
                st.session_state.logs = current_logs

                log_container.code("\n".join(current_logs), language="plaintext")

        streamlit_sink = StreamlitSink()
        _logger.remove()

        def prompt_optimizer_filter(record):
            return "optimizer" in record["name"].lower()

        _logger.add(
            streamlit_sink.write,
            format="{time:YYYY-MM-DD HH:mm:ss.SSS} | {level: <8} | {name}:{function}:{line} - {message}",
            filter=prompt_optimizer_filter,
        )
        _logger.add(METAGPT_ROOT / "logs/{time:YYYYMMDD}.txt", level="DEBUG")
        # Start optimization button
        if st.button("Start Optimization"):
            try:
                # Initialize LLM
                SPO_LLM.initialize(
                    optimize_kwargs={"model": opt_model, "temperature": opt_temp},
                    evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
                    execute_kwargs={"model": exec_model, "temperature": exec_temp},
                )

                # Create optimizer instance
                optimizer = PromptOptimizer(
                    optimized_path="workspace",
                    initial_round=initial_round,
                    max_rounds=max_rounds,
                    template=f"{template_name}.yaml",
                    name=template_name,
                )

                # Run optimization behind a spinner
                with st.spinner("Optimizing prompts..."):
                    optimizer.optimize()

                st.success("Optimization completed!")

                # Persist results in session state; they are rendered under the
                # "Optimization Results" header below, so they also survive reruns.
                prompt_path = optimizer.root_path / "prompts"
                result_data = optimizer.data_utils.load_results(prompt_path)

                st.session_state.optimization_results = result_data

            except Exception as e:
                st.error(f"An error occurred: {str(e)}")
                _logger.error(f"Error during optimization: {str(e)}")
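    # Each persisted entry is assumed (from its use in
    # display_optimization_results) to carry "round", "succeed", "prompt",
    # "tokens", and an "answers" list of {"question", "answer"} dicts.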
    if st.session_state.optimization_results:
        st.header("Optimization Results")
        display_optimization_results(st.session_state.optimization_results)

    st.markdown("---")
    st.subheader("Test Optimized Prompt")
    col1, col2 = st.columns(2)

    with col1:
        test_prompt = st.text_area("Optimized Prompt", value="", height=200, key="test_prompt")

    with col2:
        test_question = st.text_area("Your Question", value="", height=200, key="test_question")

    if st.button("Test Prompt"):
        if test_prompt and test_question:
            try:
                with st.spinner("Generating response..."):
                    SPO_LLM.initialize(
                        optimize_kwargs={"model": opt_model, "temperature": opt_temp},
                        evaluate_kwargs={"model": eval_model, "temperature": eval_temp},
                        execute_kwargs={"model": exec_model, "temperature": exec_temp},
                    )

                    llm = SPO_LLM.get_instance()
                    messages = [{"role": "user", "content": f"{test_prompt}\n\n{test_question}"}]

                    async def get_response():
                        return await llm.responser(request_type=RequestType.EXECUTE, messages=messages)
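                    # Streamlit's script thread may not have a running event
                    # loop, so create a fresh one for this single call;
                    # asyncio.run(get_response()) would be an equivalent,
                    # simpler alternative.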
                    loop = asyncio.new_event_loop()
                    asyncio.set_event_loop(loop)
                    try:
                        response = loop.run_until_complete(get_response())
                    finally:
                        loop.close()

                    st.subheader("Response:")
                    st.markdown(response)

            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
        else:
            st.warning("Please enter both prompt and question.")


if __name__ == "__main__":
    main()
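# When served by Streamlit (locally via `streamlit run app.py`, or as the entry
# point of a Streamlit Space), the script runs with __name__ == "__main__" and
# re-executes top to bottom on every interaction.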