ar9avg committed on
Commit
4ec680a
·
1 Parent(s): b00a200
Files changed (1) hide show
  1. backend/api/demo.py +69 -56
backend/api/demo.py CHANGED
@@ -176,73 +176,69 @@ async def execute_query_stream(req: ExecuteQueryRequest):
176
  # Initial generate action
177
  action = Action(repair_action="generate")
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  for attempt in range(1, max_attempts + 1):
180
  yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
181
 
182
  ep = env._episode # type: ignore[union-attr]
183
  ep.attempt_number = attempt
184
 
185
- # Generate SQL with streaming
186
- from env.sql_env import _make_client, _MODEL
187
- from openai import AsyncOpenAI
188
-
189
- if attempt == 1 or ep.current_sql is None:
190
- system_prompt = BASE_SYSTEM_PROMPT
191
- # Include previous wrong SQL if user retried after marking wrong
192
- prev_context = ""
193
- if req.previousSql:
194
- prev_context = (
195
- f"\nNOTE: A previous attempt generated the following SQL which was marked INCORRECT:\n"
196
- f"```sql\n{req.previousSql}\n```\n"
197
- f"You MUST try a completely different approach.\n"
198
- )
199
- user_msg = (
200
- f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n"
201
- f"{prev_context}\n"
202
- "Write a SQL query to answer this question."
203
- )
204
- else:
205
- from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
206
-
207
- # Bandit selects action
208
- if ep.current_features is not None:
209
- repair_enum, scores = env._bandit.select_action(ep.current_features)
210
- ucb_scores = {
211
- REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
212
- for i in range(len(scores))
213
- }
214
- action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
215
- yield {"data": json.dumps({
216
- "type": "rl_action",
217
- "action": action.repair_action,
218
- "ucb_scores": ucb_scores,
219
- })}
220
- else:
221
- repair_enum = RepairAction.REWRITE_FULL
222
- action = Action(repair_action="rewrite_full")
223
-
224
- suffix = get_repair_system_suffix(repair_enum)
225
- offending = extract_offending_token(ep.error_message or "")
226
- ctx = RepairContext(
227
- schema=obs.schema_info,
228
- question=req.question,
229
- failing_sql=ep.current_sql or "",
230
- error_message=ep.error_message or "",
231
- offending_token=offending,
232
- )
233
- system_prompt = BASE_SYSTEM_PROMPT + suffix
234
- user_msg = build_repair_user_message(repair_enum, ctx)
235
-
236
- # Stream SQL generation
237
  client = _make_client()
238
  chunks: list[str] = []
239
  try:
240
  stream = await client.chat.completions.create(
241
  model=_MODEL,
242
- messages=[
243
- {"role": "system", "content": system_prompt},
244
- {"role": "user", "content": user_msg},
245
- ],
246
  stream=True,
247
  temperature=0.1,
248
  )
@@ -385,6 +381,23 @@ async def execute_query_stream(req: ExecuteQueryRequest):
385
  })}
386
  done = True
387
  break
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
388
 
389
  total_reward = compute_episode_reward(all_step_rewards, success)
390
 
 
176
  # Initial generate action
177
  action = Action(repair_action="generate")
178
 
179
+ from env.sql_env import _make_client, _MODEL
180
+ from rl.repair_strategies import RepairContext, get_repair_system_suffix, build_repair_user_message
181
+
182
+ # Build initial user message (includes previous-wrong-SQL context if retrying)
183
+ prev_context = ""
184
+ if req.previousSql:
185
+ prev_context = (
186
+ f"\nNOTE: A previous session generated the following SQL which was marked INCORRECT:\n"
187
+ f"```sql\n{req.previousSql}\n```\n"
188
+ f"You MUST try a completely different approach.\n"
189
+ )
190
+ initial_user_msg = (
191
+ f"Schema:\n{obs.schema_info}\n\nQuestion: {req.question}\n"
192
+ f"{prev_context}\n"
193
+ "Write a SQL query to answer this question."
194
+ )
195
+
196
+ # Multi-turn conversation — grows with each failed attempt so the LLM
197
+ # sees its own history and doesn't repeat the same mistake.
198
+ conversation: list[dict] = [
199
+ {"role": "system", "content": BASE_SYSTEM_PROMPT},
200
+ {"role": "user", "content": initial_user_msg},
201
+ ]
202
+
203
  for attempt in range(1, max_attempts + 1):
204
  yield {"data": json.dumps({"type": "attempt_start", "attempt": attempt})}
205
 
206
  ep = env._episode # type: ignore[union-attr]
207
  ep.attempt_number = attempt
208
 
209
+ # On repair attempts, update system prompt with RL-selected repair suffix
210
+ if attempt > 1 and ep.current_features is not None:
211
+ repair_enum, scores = env._bandit.select_action(ep.current_features)
212
+ ucb_scores = {
213
+ REPAIR_ACTION_NAMES[RepairAction(i)]: round(scores[i], 4)
214
+ for i in range(len(scores))
215
+ }
216
+ action = Action(repair_action=REPAIR_ACTION_NAMES[repair_enum])
217
+ yield {"data": json.dumps({
218
+ "type": "rl_action",
219
+ "action": action.repair_action,
220
+ "ucb_scores": ucb_scores,
221
+ })}
222
+ # Update system prompt with repair-specific guidance
223
+ conversation[0] = {
224
+ "role": "system",
225
+ "content": BASE_SYSTEM_PROMPT + get_repair_system_suffix(repair_enum),
226
+ }
227
+ elif attempt > 1:
228
+ repair_enum = RepairAction.REWRITE_FULL
229
+ action = Action(repair_action="rewrite_full")
230
+ conversation[0] = {
231
+ "role": "system",
232
+ "content": BASE_SYSTEM_PROMPT + get_repair_system_suffix(repair_enum),
233
+ }
234
+
235
+ # Stream SQL generation using the full conversation history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
  client = _make_client()
237
  chunks: list[str] = []
238
  try:
239
  stream = await client.chat.completions.create(
240
  model=_MODEL,
241
+ messages=conversation,
 
 
 
242
  stream=True,
243
  temperature=0.1,
244
  )
 
381
  })}
382
  done = True
383
  break
384
+ else:
385
+ # Append failed attempt to conversation so the next attempt has full history.
386
+ # This prevents the LLM from repeating the same mistake on subsequent tries.
387
+ conversation.append({"role": "assistant", "content": generated_sql})
388
+ if error:
389
+ offending = extract_offending_token(error)
390
+ feedback_msg = (
391
+ f"That SQL failed with this error:\n{error}\n"
392
+ + (f"Problematic token: '{offending}'\n" if offending else "")
393
+ + "Please fix the SQL. Do NOT repeat the same mistake."
394
+ )
395
+ else:
396
+ feedback_msg = (
397
+ "That SQL ran but returned incorrect or empty results. "
398
+ "Please try a completely different approach."
399
+ )
400
+ conversation.append({"role": "user", "content": feedback_msg})
401
 
402
  total_reward = compute_episode_reward(all_step_rewards, success)
403