Spaces:
Sleeping
Sleeping
fix
Browse files- backend/api/demo.py +69 -56
backend/api/demo.py
CHANGED
|
@@ -176,73 +176,69 @@ async def execute_query_stream(req: ExecuteQueryRequest):
|
|
| 176 |
# Initial generate action
|
| 177 |
action = Action(repair_action="generate")
|
| 178 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 179 |
for attempt in range(1, max_attempts + 1):
|
| 180 |
yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
|
| 181 |
|
| 182 |
ep = env._episode # type: ignore[union-attr]
|
| 183 |
ep.attempt_number = attempt
|
| 184 |
|
| 185 |
-
#
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
| 200 |
-
|
| 201 |
-
|
| 202 |
-
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
|
| 206 |
-
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
|
| 210 |
-
|
| 211 |
-
|
| 212 |
-
for i in range(len(scores))
|
| 213 |
-
}
|
| 214 |
-
action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
|
| 215 |
-
yield {"data": json.dumps({
|
| 216 |
-
"type": "rl_action",
|
| 217 |
-
"action": action.repair_action,
|
| 218 |
-
"ucb_scores": ucb_scores,
|
| 219 |
-
})}
|
| 220 |
-
else:
|
| 221 |
-
repair_enum = RepairAction.REWRITE_FULL
|
| 222 |
-
action = Action(repair_action="rewrite_full")
|
| 223 |
-
|
| 224 |
-
suffix = get_repair_system_suffix(repair_enum)
|
| 225 |
-
offending = extract_offending_token(ep.error_message or "")
|
| 226 |
-
ctx = RepairContext(
|
| 227 |
-
schema=obs.schema_info,
|
| 228 |
-
question=req.question,
|
| 229 |
-
failing_sql=ep.current_sql or "",
|
| 230 |
-
error_message=ep.error_message or "",
|
| 231 |
-
offending_token=offending,
|
| 232 |
-
)
|
| 233 |
-
system_prompt = BASE_SYSTEM_PROMPT + suffix
|
| 234 |
-
user_msg = build_repair_user_message(repair_enum, ctx)
|
| 235 |
-
|
| 236 |
-
# Stream SQL generation
|
| 237 |
client = _make_client()
|
| 238 |
chunks: list[str] = []
|
| 239 |
try:
|
| 240 |
stream = await client.chat.completions.create(
|
| 241 |
model=_MODEL,
|
| 242 |
-
messages=
|
| 243 |
-
{"role": "system", "content": system_prompt},
|
| 244 |
-
{"role": "user", "content": user_msg},
|
| 245 |
-
],
|
| 246 |
stream=True,
|
| 247 |
temperature=0.1,
|
| 248 |
)
|
|
@@ -385,6 +381,23 @@ async def execute_query_stream(req: ExecuteQueryRequest):
|
|
| 385 |
})}
|
| 386 |
done = True
|
| 387 |
break
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
| 389 |
total_reward = compute_episode_reward(all_step_rewards, success)
|
| 390 |
|
|
|
|
| 176 |
# Initial generate action
|
| 177 |
action = Action(repair_action="generate")
|
| 178 |
|
| 179 |
+
from env.sql_env import _make_client, _MODEL
|
| 180 |
+
from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
|
| 181 |
+
|
| 182 |
+
# Build initial user message (includes previous-wrong-SQL context if retrying)
|
| 183 |
+
prev_context = ""
|
| 184 |
+
if req.previousSql:
|
| 185 |
+
prev_context = (
|
| 186 |
+
f"\nNOTE: A previous session generated the following SQL which was marked INCORRECT:\n"
|
| 187 |
+
f"```sql\n{req.previousSql}\n```\n"
|
| 188 |
+
f"You MUST try a completely different approach.\n"
|
| 189 |
+
)
|
| 190 |
+
initial_user_msg = (
|
| 191 |
+
f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n"
|
| 192 |
+
f"{prev_context}\n"
|
| 193 |
+
"Write a SQL query to answer this question."
|
| 194 |
+
)
|
| 195 |
+
|
| 196 |
+
# Multi-turn conversation — grows with each failed attempt so the LLM
|
| 197 |
+
# sees its own history and doesn't repeat the same mistake.
|
| 198 |
+
conversation: list[dict] = [
|
| 199 |
+
{"role": "system", "content": BASE_SYSTEM_PROMPT},
|
| 200 |
+
{"role": "user", "content": initial_user_msg},
|
| 201 |
+
]
|
| 202 |
+
|
| 203 |
for attempt in range(1, max_attempts + 1):
|
| 204 |
yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
|
| 205 |
|
| 206 |
ep = env._episode # type: ignore[union-attr]
|
| 207 |
ep.attempt_number = attempt
|
| 208 |
|
| 209 |
+
# On repair attempts, update system prompt with RL-selected repair suffix
|
| 210 |
+
if attempt > 1 and ep.current_features is not None:
|
| 211 |
+
repair_enum, scores = env._bandit.select_action(ep.current_features)
|
| 212 |
+
ucb_scores = {
|
| 213 |
+
REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
|
| 214 |
+
for i in range(len(scores))
|
| 215 |
+
}
|
| 216 |
+
action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
|
| 217 |
+
yield {"data": json.dumps({
|
| 218 |
+
"type": "rl_action",
|
| 219 |
+
"action": action.repair_action,
|
| 220 |
+
"ucb_scores": ucb_scores,
|
| 221 |
+
})}
|
| 222 |
+
# Update system prompt with repair-specific guidance
|
| 223 |
+
conversation[0] = {
|
| 224 |
+
"role": "system",
|
| 225 |
+
"content": BASE_SYSTEM_PROMPT + get_repair_system_suffix(repair_enum),
|
| 226 |
+
}
|
| 227 |
+
elif attempt > 1:
|
| 228 |
+
repair_enum = RepairAction.REWRITE_FULL
|
| 229 |
+
action = Action(repair_action="rewrite_full")
|
| 230 |
+
conversation[0] = {
|
| 231 |
+
"role": "system",
|
| 232 |
+
"content": BASE_SYSTEM_PROMPT + get_repair_system_suffix(repair_enum),
|
| 233 |
+
}
|
| 234 |
+
|
| 235 |
+
# Stream SQL generation using the full conversation history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 236 |
client = _make_client()
|
| 237 |
chunks: list[str] = []
|
| 238 |
try:
|
| 239 |
stream = await client.chat.completions.create(
|
| 240 |
model=_MODEL,
|
| 241 |
+
messages=conversation,
|
|
|
|
|
|
|
|
|
|
| 242 |
stream=True,
|
| 243 |
temperature=0.1,
|
| 244 |
)
|
|
|
|
| 381 |
})}
|
| 382 |
done = True
|
| 383 |
break
|
| 384 |
+
else:
|
| 385 |
+
# Append failed attempt to conversation so the next attempt has full history.
|
| 386 |
+
# This prevents the LLM from repeating the same mistake on subsequent tries.
|
| 387 |
+
conversation.append({"role": "assistant", "content": generated_sql})
|
| 388 |
+
if error:
|
| 389 |
+
offending = extract_offending_token(error)
|
| 390 |
+
feedback_msg = (
|
| 391 |
+
f"That SQL failed with this error:\n{error}\n"
|
| 392 |
+
+ (f"Problematic token: '{offending}'\n" if offending else "")
|
| 393 |
+
+ "Please fix the SQL. Do NOT repeat the same mistake."
|
| 394 |
+
)
|
| 395 |
+
else:
|
| 396 |
+
feedback_msg = (
|
| 397 |
+
"That SQL ran but returned incorrect or empty results. "
|
| 398 |
+
"Please try a completely different approach."
|
| 399 |
+
)
|
| 400 |
+
conversation.append({"role": "user", "content": feedback_msg})
|
| 401 |
|
| 402 |
total_reward = compute_episode_reward(all_step_rewards, success)
|
| 403 |
|