DataBoySu commited on
Commit
9670629
·
1 Parent(s): acfb96b

infernece

Browse files
Files changed (1) hide show
  1. inference.py +300 -82
inference.py CHANGED
@@ -3,21 +3,20 @@ AML Investigator - Baseline Inference Script
3
  Loops through all 3 tasks to satisfy the Phase 2 Validator.
4
  """
5
  import asyncio
6
- import os
7
  import json
8
- import textwrap
9
- import sys
10
  import re
11
- from typing import List, Optional
 
 
 
12
  from openai import OpenAI
13
 
14
- # Adjust the import based on your openenv server setup
15
- # If running locally without docker wrapper for validation, you might need to import your Env directly
16
  from server.AML_env_environment import AmlEnvironment
17
  from models import AmlAction
18
 
19
 
20
- API_BASE_URL = os.getenv("API_BASE_URL") or "http://127.0.0.1:1234/v1"
21
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
22
  HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
23
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
@@ -27,31 +26,49 @@ TASKS = ["aml_easy", "aml_medium", "aml_hard"]
27
  BENCHMARK = "aml_investigator"
28
  MAX_STEPS = 25
29
 
 
 
 
 
 
30
  SYSTEM_PROMPT = textwrap.dedent(
31
  """
32
  You are a Tier 1 AML Compliance Investigator.
33
  You must investigate the provided alert by querying the bank's internal APIs.
34
-
35
  You have a strict API budget. Be efficient.
36
  Respond with EXACTLY ONE valid JSON object representing your action. Do not include markdown formatting or explanations.
37
-
38
  Available Action JSON Schemas:
39
  1. {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}}
40
  2. {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}}
41
  3. {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}}
42
  4. {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}} (Use "CLEAR" for False Positives with empty evidence_links).
43
 
44
- Token-saving style rule:
45
- - Think in caveman style (short, simple words).
46
- - Never output prose. Output JSON only.
 
 
 
 
 
 
 
 
 
 
 
 
 
47
 
48
- Data rule:
49
  - get_kyc_record must use ENT-XXXX only, never ACC-XXXX.
 
 
50
  """
51
  ).strip()
52
 
53
- FALLBACK_ACTION_JSON = '{"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}}'
54
-
55
 
56
  def _extract_text_from_chat_completion(completion: object) -> str:
57
  choices = getattr(completion, "choices", None) or []
@@ -74,6 +91,10 @@ def _extract_text_from_chat_completion(completion: object) -> str:
74
  text_val = item.get("text")
75
  if isinstance(text_val, str):
76
  chunks.append(text_val)
 
 
 
 
77
  merged = "".join(chunks).strip()
78
  if merged:
79
  return merged
@@ -94,6 +115,10 @@ def _extract_text_from_responses_api(response: object) -> str:
94
  text_val = getattr(part, "text", None)
95
  if isinstance(text_val, str):
96
  chunks.append(text_val)
 
 
 
 
97
 
98
  merged = "".join(chunks).strip()
99
  if merged:
@@ -131,17 +156,107 @@ def _coerce_json_object(raw_text: str) -> str:
131
  return text
132
 
133
 
134
- def _build_recovery_action_from_obs(obs_dict: dict) -> dict:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
135
  """Use a non-terminal fallback action when model output is malformed."""
136
  alert = str(obs_dict.get("alert_details", "") or "")
137
  match = re.search(r"ACC-\d+", alert)
138
  if match:
 
 
 
139
  return {
140
  "action": {
141
  "action_type": "query_transactions",
142
- "account_id": match.group(0),
143
  "limit": 10,
144
- "offset": 0,
145
  }
146
  }
147
  return {
@@ -153,8 +268,32 @@ def _build_recovery_action_from_obs(obs_dict: dict) -> dict:
153
  }
154
 
155
 
156
- def _ensure_valid_action_json(raw_text: str, obs_dict: dict) -> str:
157
- """Guarantee a valid action JSON string for downstream parsing."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
  candidate = _coerce_json_object(raw_text)
159
  try:
160
  payload = json.loads(candidate)
@@ -166,15 +305,11 @@ def _ensure_valid_action_json(raw_text: str, obs_dict: dict) -> str:
166
  action_type = action.get("action_type")
167
  if not isinstance(action_type, str):
168
  raise ValueError("missing 'action_type' string")
 
169
  return json.dumps(payload, ensure_ascii=True)
170
- except Exception as exc:
171
- recovery_json = _build_recovery_action_from_obs(obs_dict)
172
- print(
173
- f"[DEBUG] Non-JSON/invalid model action; using recovery action ({exc})",
174
- file=sys.stderr,
175
- flush=True,
176
- )
177
- return json.dumps(recovery_json, ensure_ascii=True)
178
 
179
  def log_start(task: str, env: str, model: str) -> None:
180
  print(f"[START] task={task} env={env} model={model}", flush=True)
@@ -185,6 +320,7 @@ def log_step(step: int, action: str, reward: float, done: bool, error: Optional[
185
  done_val = str(done).lower()
186
  print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
187
 
 
188
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
189
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
190
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
@@ -201,10 +337,37 @@ def log_thought(step: int, thought: Optional[object]) -> None:
201
  compact = compact.replace("\n", " ").strip()
202
  print(f"[THOUGHT] step={step} thought={compact}", file=sys.stderr, flush=True)
203
 
204
- def get_model_message(client: OpenAI, obs_dict: dict, history: List[str]) -> str:
205
- history_block = "\n".join(history[-5:]) if history else "No previous steps."
206
- user_prompt = f"Observation:\n{json.dumps(obs_dict, indent=2)}\n\nHistory:\n{history_block}\n\nProvide your next JSON action:"
207
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
208
  try:
209
  completion = client.chat.completions.create(
210
  model=MODEL_NAME,
@@ -213,69 +376,81 @@ def get_model_message(client: OpenAI, obs_dict: dict, history: List[str]) -> str
213
  {"role": "user", "content": user_prompt},
214
  ],
215
  temperature=0.0,
216
- max_tokens=1000,
217
- response_format={"type": "json_object"},
218
  )
219
- return _ensure_valid_action_json(_extract_text_from_chat_completion(completion), obs_dict)
 
 
 
 
220
  except Exception as chat_exc:
221
- # Retry via Responses API for OpenAI-compatible providers that do not
222
- # populate chat.completions choices consistently.
223
- try:
224
- response = client.responses.create(
225
- model=MODEL_NAME,
226
- instructions=SYSTEM_PROMPT,
227
- input=user_prompt,
228
- max_output_tokens=1000,
229
- )
230
- return _ensure_valid_action_json(_extract_text_from_responses_api(response), obs_dict)
231
- except Exception as responses_exc:
232
- try:
233
- completion = client.completions.create(
234
- model=MODEL_NAME,
235
- prompt=f"{SYSTEM_PROMPT}\n\n{user_prompt}",
236
- temperature=0.0,
237
- max_tokens=200,
238
- )
239
- return _ensure_valid_action_json(_extract_text_from_completions_api(completion), obs_dict)
240
- except Exception as completions_exc:
241
- print(
242
- (
243
- "[DEBUG] Model request failed: "
244
- f"chat={chat_exc}; responses={responses_exc}; completions={completions_exc}"
245
- ),
246
- file=sys.stderr,
247
- flush=True,
248
- )
249
- return _ensure_valid_action_json(FALLBACK_ACTION_JSON, obs_dict)
 
 
 
 
 
 
 
250
 
251
  async def main() -> None:
252
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
253
-
254
- # Initialize your environment natively for the baseline script
255
  env = AmlEnvironment()
256
 
257
  for task_name in TASKS:
258
- history: List[str] = []
259
  rewards: List[float] = []
260
  steps_taken = 0
261
  score = 0.0
262
  success = False
263
  had_parse_error = False
 
 
264
 
265
  log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
266
 
267
  try:
268
  obs = env.reset(task=task_name)
269
-
270
  for step in range(1, MAX_STEPS + 1):
271
  if obs.done:
272
  break
273
 
274
  obs_dict = obs.model_dump()
275
- action_str = get_model_message(client, obs_dict, history)
276
-
277
- # Parse LLM string to Pydantic Model
 
278
  action_for_log = action_str
 
279
  try:
280
  clean_str = _coerce_json_object(action_str)
281
  action_json = json.loads(clean_str)
@@ -285,32 +460,74 @@ async def main() -> None:
285
  thought_for_log = f"do {action_type} now"
286
  log_thought(step=step, thought=thought_for_log)
287
  action_obj = AmlAction.model_validate(action_json)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
288
  error = None
289
  except Exception as e:
290
- # Errors are data! If the LLM writes bad JSON, we catch it and force a dummy action
291
- # so the environment can return a schema error to the LLM.
292
  had_parse_error = True
293
  error = f"JSON Parse/Schema Error: {str(e)}"
294
  log_thought(step=step, thought="parse fail; use recovery action")
295
- recovery_json = _build_recovery_action_from_obs(obs_dict)
296
- action_obj = AmlAction.model_validate(recovery_json)
297
- action_for_log = json.dumps(recovery_json, ensure_ascii=True)
 
 
 
 
 
 
 
 
 
298
 
299
  obs = env.step(action_obj)
300
-
301
  reward = obs.reward or 0.0
302
  done = obs.done
303
 
304
  rewards.append(reward)
305
  steps_taken = step
306
-
307
- log_step(step=step, action=action_for_log.replace('\n', ''), reward=reward, done=done, error=error)
308
- history.append(f"Step {step}: Action: {action_str} -> Result: {obs.last_action_result} | Error: {obs.error_message}")
 
 
 
 
 
 
 
 
 
 
309
 
310
  if done:
311
  break
312
 
313
- # Keep score in open interval (0,1) and avoid false positives on parse failures.
314
  if had_parse_error or obs.error_message:
315
  score = 0.05
316
  elif "submit_decision" in (obs.last_action or ""):
@@ -323,5 +540,6 @@ async def main() -> None:
323
  finally:
324
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
325
 
 
326
  if __name__ == "__main__":
327
- asyncio.run(main())
 
3
  Loops through all 3 tasks to satisfy the Phase 2 Validator.
4
  """
5
  import asyncio
 
6
  import json
7
+ import os
 
8
  import re
9
+ import sys
10
+ import textwrap
11
+ from typing import Any, Dict, List, Optional, Tuple
12
+
13
  from openai import OpenAI
14
 
 
 
15
  from server.AML_env_environment import AmlEnvironment
16
  from models import AmlAction
17
 
18
 
19
+ API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1"
20
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
21
  HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
22
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
 
26
  BENCHMARK = "aml_investigator"
27
  MAX_STEPS = 25
28
 
29
+ OBS_RESULT_MAX_ITEMS = 8
30
+ HISTORY_MAX_STEPS = 3
31
+ HISTORY_MAX_CHARS = 1600
32
+ TEXT_CLIP_CHARS = 320
33
+
34
  SYSTEM_PROMPT = textwrap.dedent(
35
  """
36
  You are a Tier 1 AML Compliance Investigator.
37
  You must investigate the provided alert by querying the bank's internal APIs.
38
+
39
  You have a strict API budget. Be efficient.
40
  Respond with EXACTLY ONE valid JSON object representing your action. Do not include markdown formatting or explanations.
41
+
42
  Available Action JSON Schemas:
43
  1. {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}}
44
  2. {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}}
45
  3. {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}}
46
  4. {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}} (Use "CLEAR" for False Positives with empty evidence_links).
47
 
48
+ Required top-level JSON format:
49
+ {
50
+ "thought": {
51
+ "observation": "...",
52
+ "plan": "...",
53
+ "action": "..."
54
+ },
55
+ "action": {...}
56
+ }
57
+
58
+ Thought rules:
59
+ - Use caveman style: short, simple, low-token wording.
60
+ - Keep thought informative but brief.
61
+ - observation = what clue found now.
62
+ - plan = next investigation goal.
63
+ - action = exact tool call you will make now.
64
 
65
+ Data rules:
66
  - get_kyc_record must use ENT-XXXX only, never ACC-XXXX.
67
+ - submit_decision only when evidence is enough; else keep investigating.
68
+ - Use only the alert, the current observation, and the recent history shown here.
69
  """
70
  ).strip()
71
 
 
 
72
 
73
  def _extract_text_from_chat_completion(completion: object) -> str:
74
  choices = getattr(completion, "choices", None) or []
 
91
  text_val = item.get("text")
92
  if isinstance(text_val, str):
93
  chunks.append(text_val)
94
+ else:
95
+ text_val = getattr(item, "text", None)
96
+ if isinstance(text_val, str):
97
+ chunks.append(text_val)
98
  merged = "".join(chunks).strip()
99
  if merged:
100
  return merged
 
115
  text_val = getattr(part, "text", None)
116
  if isinstance(text_val, str):
117
  chunks.append(text_val)
118
+ elif isinstance(part, dict):
119
+ maybe_text = part.get("text")
120
+ if isinstance(maybe_text, str):
121
+ chunks.append(maybe_text)
122
 
123
  merged = "".join(chunks).strip()
124
  if merged:
 
156
  return text
157
 
158
 
159
+ def _clip_text(value: Any, max_chars: int = TEXT_CLIP_CHARS) -> str:
160
+ text = str(value).replace("\n", " ").strip()
161
+ if len(text) <= max_chars:
162
+ return text
163
+ return text[: max_chars - 3] + "..."
164
+
165
+
166
+ def _compact_record(record: Dict[str, Any]) -> Dict[str, Any]:
167
+ keep_keys = [
168
+ "txn_id",
169
+ "timestamp",
170
+ "sender_account",
171
+ "receiver_account",
172
+ "amount",
173
+ "memo_text",
174
+ "account_id",
175
+ "owner_entity_id",
176
+ "status",
177
+ "entity_id",
178
+ "name",
179
+ "type",
180
+ "registration_address",
181
+ "directors",
182
+ ]
183
+ compact: Dict[str, Any] = {}
184
+ for key in keep_keys:
185
+ if key not in record:
186
+ continue
187
+ value = record.get(key)
188
+ if key == "directors" and isinstance(value, list):
189
+ compact[key] = value[:4]
190
+ if len(value) > 4:
191
+ compact["directors_truncated"] = len(value) - 4
192
+ continue
193
+ if isinstance(value, str):
194
+ compact[key] = _clip_text(value, max_chars=180)
195
+ else:
196
+ compact[key] = value
197
+ return compact
198
+
199
+
200
+ def _compact_action_result(last_action: Optional[str], value: Any) -> Any:
201
+ if value is None:
202
+ return None
203
+ if isinstance(value, list):
204
+ items = []
205
+ for item in value[:OBS_RESULT_MAX_ITEMS]:
206
+ if isinstance(item, dict):
207
+ items.append(_compact_record(item))
208
+ else:
209
+ items.append(_clip_text(item))
210
+ return {
211
+ "kind": "list",
212
+ "count": len(value),
213
+ "items": items,
214
+ "truncated": len(value) > OBS_RESULT_MAX_ITEMS,
215
+ "source_action": last_action,
216
+ }
217
+ if isinstance(value, dict):
218
+ return _compact_record(value)
219
+ if isinstance(value, str):
220
+ return _clip_text(value, max_chars=420)
221
+ return value
222
+
223
+
224
+ def _build_model_observation(obs_dict: Dict[str, Any]) -> Dict[str, Any]:
225
+ return {
226
+ "alert_details": obs_dict.get("alert_details"),
227
+ "budget_remaining": obs_dict.get("budget_remaining"),
228
+ "last_action": obs_dict.get("last_action"),
229
+ "last_action_result": _compact_action_result(obs_dict.get("last_action"), obs_dict.get("last_action_result")),
230
+ "error_message": _clip_text(obs_dict.get("error_message")) if obs_dict.get("error_message") else None,
231
+ "done": obs_dict.get("done"),
232
+ "reward": obs_dict.get("reward"),
233
+ }
234
+
235
+
236
+ def _render_history(history: List[Dict[str, Any]]) -> str:
237
+ if not history:
238
+ return "No previous steps."
239
+ entries = history[-HISTORY_MAX_STEPS:]
240
+ lines = [json.dumps(item, ensure_ascii=True, separators=(",", ":")) for item in entries]
241
+ while lines and len("\n".join(lines)) > HISTORY_MAX_CHARS:
242
+ lines.pop(0)
243
+ return "\n".join(lines) if lines else "No previous steps."
244
+
245
+
246
+ def _build_recovery_action_from_obs(obs_dict: dict, next_offsets: Dict[str, int]) -> dict:
247
  """Use a non-terminal fallback action when model output is malformed."""
248
  alert = str(obs_dict.get("alert_details", "") or "")
249
  match = re.search(r"ACC-\d+", alert)
250
  if match:
251
+ account_id = match.group(0)
252
+ offset = next_offsets.get(account_id, 0)
253
+ next_offsets[account_id] = offset + 10
254
  return {
255
  "action": {
256
  "action_type": "query_transactions",
257
+ "account_id": account_id,
258
  "limit": 10,
259
+ "offset": offset,
260
  }
261
  }
262
  return {
 
268
  }
269
 
270
 
271
+ def _normalize_thought(payload: Dict[str, Any]) -> None:
272
+ action = payload.get("action") if isinstance(payload.get("action"), dict) else {}
273
+ action_type = action.get("action_type", "unknown")
274
+ if "thought" not in payload or not isinstance(payload.get("thought"), dict):
275
+ payload["thought"] = {
276
+ "observation": "see current clue now.",
277
+ "plan": "find next real link.",
278
+ "action": f"do {action_type} now.",
279
+ }
280
+ return
281
+
282
+ thought = payload["thought"]
283
+ for key, fallback in (
284
+ ("observation", "see clue now."),
285
+ ("plan", "next check key link."),
286
+ ("action", f"do {action_type} now."),
287
+ ):
288
+ value = thought.get(key)
289
+ if not isinstance(value, str) or not value.strip():
290
+ thought[key] = fallback
291
+ else:
292
+ thought[key] = _clip_text(value, max_chars=140)
293
+
294
+
295
+ def _try_validate_action_json(raw_text: str) -> Optional[str]:
296
+ """Return canonical JSON string if valid, else None."""
297
  candidate = _coerce_json_object(raw_text)
298
  try:
299
  payload = json.loads(candidate)
 
305
  action_type = action.get("action_type")
306
  if not isinstance(action_type, str):
307
  raise ValueError("missing 'action_type' string")
308
+ _normalize_thought(payload)
309
  return json.dumps(payload, ensure_ascii=True)
310
+ except Exception:
311
+ return None
312
+
 
 
 
 
 
313
 
314
  def log_start(task: str, env: str, model: str) -> None:
315
  print(f"[START] task={task} env={env} model={model}", flush=True)
 
320
  done_val = str(done).lower()
321
  print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
322
 
323
+
324
  def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
325
  rewards_str = ",".join(f"{r:.2f}" for r in rewards)
326
  print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
 
337
  compact = compact.replace("\n", " ").strip()
338
  print(f"[THOUGHT] step={step} thought={compact}", file=sys.stderr, flush=True)
339
 
340
+
341
+ def get_model_message(
342
+ client: OpenAI,
343
+ obs_dict: dict,
344
+ history: List[Dict[str, Any]],
345
+ next_offsets: Dict[str, int],
346
+ ) -> Tuple[str, bool]:
347
+ model_obs = _build_model_observation(obs_dict)
348
+ history_block = _render_history(history)
349
+ user_prompt = (
350
+ f"Observation:\n{json.dumps(model_obs, ensure_ascii=True, indent=2)}\n\n"
351
+ f"History:\n{history_block}\n\n"
352
+ "Return exactly one JSON object with keys: thought, action."
353
+ )
354
+ parse_errors: List[str] = []
355
+
356
+ try:
357
+ response = client.responses.create(
358
+ model=MODEL_NAME,
359
+ instructions=SYSTEM_PROMPT,
360
+ input=user_prompt,
361
+ max_output_tokens=700,
362
+ )
363
+ raw_text = _extract_text_from_responses_api(response)
364
+ canonical = _try_validate_action_json(raw_text)
365
+ if canonical is not None:
366
+ return canonical, False
367
+ parse_errors.append("responses:invalid_json")
368
+ except Exception as responses_exc:
369
+ parse_errors.append(f"responses:{responses_exc}")
370
+
371
  try:
372
  completion = client.chat.completions.create(
373
  model=MODEL_NAME,
 
376
  {"role": "user", "content": user_prompt},
377
  ],
378
  temperature=0.0,
379
+ max_tokens=700,
 
380
  )
381
+ raw_text = _extract_text_from_chat_completion(completion)
382
+ canonical = _try_validate_action_json(raw_text)
383
+ if canonical is not None:
384
+ return canonical, False
385
+ parse_errors.append("chat:invalid_json")
386
  except Exception as chat_exc:
387
+ parse_errors.append(f"chat:{chat_exc}")
388
+
389
+ try:
390
+ completion = client.completions.create(
391
+ model=MODEL_NAME,
392
+ prompt=f"{SYSTEM_PROMPT}\n\n{user_prompt}",
393
+ temperature=0.0,
394
+ max_tokens=280,
395
+ )
396
+ raw_text = _extract_text_from_completions_api(completion)
397
+ canonical = _try_validate_action_json(raw_text)
398
+ if canonical is not None:
399
+ return canonical, False
400
+ parse_errors.append("completions:invalid_json")
401
+ except Exception as completions_exc:
402
+ parse_errors.append(f"completions:{completions_exc}")
403
+
404
+ recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets)
405
+ print(
406
+ (
407
+ "[DEBUG] Non-JSON/invalid model action; using recovery action "
408
+ f"({'; '.join(parse_errors)})"
409
+ ),
410
+ file=sys.stderr,
411
+ flush=True,
412
+ )
413
+ recovery_payload = {
414
+ "thought": {
415
+ "observation": "model output bad json.",
416
+ "plan": "use safe step. keep investigate.",
417
+ "action": "query alert account next page.",
418
+ },
419
+ "action": recovery_json["action"],
420
+ }
421
+ return json.dumps(recovery_payload, ensure_ascii=True), True
422
+
423
 
424
  async def main() -> None:
425
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
 
 
426
  env = AmlEnvironment()
427
 
428
  for task_name in TASKS:
429
+ history: List[Dict[str, Any]] = []
430
  rewards: List[float] = []
431
  steps_taken = 0
432
  score = 0.0
433
  success = False
434
  had_parse_error = False
435
+ next_offsets: Dict[str, int] = {}
436
+ query_seen_counts: Dict[Tuple[str, int], int] = {}
437
 
438
  log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
439
 
440
  try:
441
  obs = env.reset(task=task_name)
442
+
443
  for step in range(1, MAX_STEPS + 1):
444
  if obs.done:
445
  break
446
 
447
  obs_dict = obs.model_dump()
448
+ action_str, used_recovery = get_model_message(client, obs_dict, history, next_offsets)
449
+ if used_recovery:
450
+ had_parse_error = True
451
+
452
  action_for_log = action_str
453
+ action_payload_for_history: Dict[str, Any] = {}
454
  try:
455
  clean_str = _coerce_json_object(action_str)
456
  action_json = json.loads(clean_str)
 
460
  thought_for_log = f"do {action_type} now"
461
  log_thought(step=step, thought=thought_for_log)
462
  action_obj = AmlAction.model_validate(action_json)
463
+
464
+ action_payload_for_history = action_json.get("action", {}) if isinstance(action_json, dict) else {}
465
+ action_for_log = json.dumps({"action": action_payload_for_history}, ensure_ascii=True)
466
+ if action_payload_for_history.get("action_type") == "query_transactions":
467
+ acc = action_payload_for_history.get("account_id")
468
+ offset = int(action_payload_for_history.get("offset", 0))
469
+ limit = int(action_payload_for_history.get("limit", 10))
470
+ if isinstance(acc, str):
471
+ query_key = (acc, offset)
472
+ query_seen_counts[query_key] = query_seen_counts.get(query_key, 0) + 1
473
+ # Hard guardrail: avoid wasting budget on repeated same page.
474
+ if task_name == "aml_hard" and query_seen_counts[query_key] > 2:
475
+ new_offset = max(next_offsets.get(acc, offset + max(limit, 1)), offset + max(limit, 1))
476
+ action_json["action"]["offset"] = new_offset
477
+ action_json["thought"]["plan"] = _clip_text(
478
+ f"repeat page seen. move to next offset {new_offset}.",
479
+ max_chars=120,
480
+ )
481
+ action_json["thought"]["action"] = _clip_text(
482
+ f"query_transactions {acc} offset {new_offset}",
483
+ max_chars=120,
484
+ )
485
+ action_for_log = json.dumps(action_json, ensure_ascii=True)
486
+ action_obj = AmlAction.model_validate(action_json)
487
+ offset = new_offset
488
+ next_offsets[acc] = max(next_offsets.get(acc, 0), offset + max(limit, 1))
489
  error = None
490
  except Exception as e:
 
 
491
  had_parse_error = True
492
  error = f"JSON Parse/Schema Error: {str(e)}"
493
  log_thought(step=step, thought="parse fail; use recovery action")
494
+ recovery_json = _build_recovery_action_from_obs(obs_dict, next_offsets)
495
+ recovery_payload = {
496
+ "thought": {
497
+ "observation": "parse fail now.",
498
+ "plan": "safe step, keep digging.",
499
+ "action": "query alert next page.",
500
+ },
501
+ "action": recovery_json["action"],
502
+ }
503
+ action_obj = AmlAction.model_validate(recovery_payload)
504
+ action_payload_for_history = recovery_payload["action"]
505
+ action_for_log = json.dumps({"action": action_payload_for_history}, ensure_ascii=True)
506
 
507
  obs = env.step(action_obj)
508
+
509
  reward = obs.reward or 0.0
510
  done = obs.done
511
 
512
  rewards.append(reward)
513
  steps_taken = step
514
+
515
+ log_step(step=step, action=action_for_log.replace("\n", ""), reward=reward, done=done, error=error)
516
+ history.append(
517
+ {
518
+ "step": step,
519
+ "action": action_payload_for_history,
520
+ "result": _compact_action_result(obs.last_action, obs.last_action_result),
521
+ "error": _clip_text(obs.error_message) if obs.error_message else None,
522
+ "budget_remaining": obs.budget_remaining,
523
+ }
524
+ )
525
+ if len(history) > 24:
526
+ history = history[-24:]
527
 
528
  if done:
529
  break
530
 
 
531
  if had_parse_error or obs.error_message:
532
  score = 0.05
533
  elif "submit_decision" in (obs.last_action or ""):
 
540
  finally:
541
  log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
542
 
543
+
544
  if __name__ == "__main__":
545
+ asyncio.run(main())