DataBoySu committed on
Commit
1a65601
·
1 Parent(s): 7ae8bca
Files changed (2) hide show
  1. README.md +1 -1
  2. inference.py +117 -155
README.md CHANGED
@@ -1,5 +1,5 @@
1
  ---
2
- title: AML Investigator OpenEnv RL Environment
3
  emoji: 🕵️
4
  colorFrom: indigo
5
  colorTo: red
 
1
  ---
2
+ title: Anti Money Laundering RL Env
3
  emoji: 🕵️
4
  colorFrom: indigo
5
  colorTo: red
inference.py CHANGED
@@ -1,27 +1,19 @@
1
- """Baseline inference runner for AML_env.
2
-
3
- The script supports local LM Studio via an OpenAI-compatible base URL and keeps
4
- the multi-task loop expected by the project validator.
5
  """
6
-
7
- from __future__ import annotations
8
-
9
- import json
10
  import os
11
- from pathlib import Path
12
- from typing import Any, Optional
13
-
 
14
  from openai import OpenAI
15
 
16
- try:
17
- from AML_env.client import AmlEnv
18
- from AML_env.models import AmlAction
19
- except Exception:
20
- ROOT_DIR = Path(__file__).resolve().parent
21
- if str(ROOT_DIR) not in os.sys.path:
22
- os.sys.path.insert(0, str(ROOT_DIR))
23
- from client import AmlEnv
24
- from models import AmlAction
25
 
26
 
27
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") or "http://127.0.0.1:1234"
@@ -30,80 +22,43 @@ HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
30
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
31
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:7860")
32
  TASK_NAME = os.getenv("TASK_NAME", "aml_easy")
33
- BENCHMARK = os.getenv("BENCHMARK", "AML_env")
34
- MAX_STEPS = int(os.getenv("MAX_STEPS", "25"))
35
  TASKS = ["aml_easy", "aml_medium", "aml_hard"]
36
-
37
- if not HF_TOKEN:
38
- raise ValueError("HF_TOKEN environment variable is required")
39
-
40
- SYSTEM_PROMPT = (
41
- "You are a Tier 1 AML Compliance Investigator. "
42
- "Return exactly one JSON object with the nested shape {\"action\": {...}}. "
43
- "Allowed action types: query_transactions, search_transactions, get_kyc_record, submit_decision. "
44
- "Do not output markdown, code fences, or explanations."
45
- )
46
-
47
-
48
- def _clean_text(value: str) -> str:
49
- return value.replace("\n", " ").replace("\r", " ").strip()
50
-
51
-
52
- def _format_reward(value: float) -> str:
53
- return f"{value:.2f}"
54
-
55
-
56
- def _format_action(action: AmlAction) -> str:
57
- return json.dumps(action.model_dump(), separators=(",", ":"), ensure_ascii=True)
58
-
59
-
60
- def _format_error(error: Optional[str]) -> str:
61
- return error if error else "null"
62
-
63
 
64
  def log_start(task: str, env: str, model: str) -> None:
65
  print(f"[START] task={task} env={env} model={model}", flush=True)
66
 
67
 
68
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
69
- print(
70
- f"[STEP] step={step} action={action} reward={_format_reward(reward)} done={str(done).lower()} error={_format_error(error)}",
71
- flush=True,
72
- )
73
-
74
-
75
- def log_end(success: bool, steps: int, rewards: list[float]) -> None:
76
- rewards_str = ",".join(_format_reward(r) for r in rewards)
77
- print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
78
-
79
-
80
- def _build_env() -> AmlEnv:
81
- if LOCAL_IMAGE_NAME:
82
- return AmlEnv.from_docker_image(LOCAL_IMAGE_NAME)
83
- return AmlEnv(base_url=ENV_BASE_URL)
84
-
85
-
86
- def _fallback_action() -> AmlAction:
87
- return AmlAction.model_validate(
88
- {
89
- "action": {
90
- "action_type": "submit_decision",
91
- "decision": "CLEAR",
92
- "evidence_links": [],
93
- }
94
- }
95
- )
96
-
97
-
98
- def _model_action(client: OpenAI, observation: Any, history: list[str]) -> AmlAction:
99
- history_block = "\n".join(history[-5:]) if history else "No prior steps."
100
- user_prompt = (
101
- f"Alert:\n{observation.alert_details}\n\n"
102
- f"Observation:\n{json.dumps(observation.model_dump(), separators=(",", ":"), ensure_ascii=True)}\n\n"
103
- f"History:\n{history_block}\n\n"
104
- "Return the next JSON action."
105
- )
106
-
107
  try:
108
  completion = client.chat.completions.create(
109
  model=MODEL_NAME,
@@ -111,75 +66,82 @@ def _model_action(client: OpenAI, observation: Any, history: list[str]) -> AmlAc
111
  {"role": "system", "content": SYSTEM_PROMPT},
112
  {"role": "user", "content": user_prompt},
113
  ],
114
- temperature=0.0,
115
- max_tokens=256,
116
  )
117
- content = (completion.choices[0].message.content or "").strip()
118
- if not content:
119
- return _fallback_action()
120
- try:
121
- return AmlAction.model_validate_json(content)
122
- except Exception:
123
- return AmlAction.model_validate(json.loads(content))
124
- except Exception:
125
- return _fallback_action()
126
-
127
-
128
- def run_episode(client: OpenAI, env: AmlEnv, task_name: str) -> tuple[bool, int, float, list[float]]:
129
- history: list[str] = []
130
- rewards: list[float] = []
131
- steps_taken = 0
132
- success = False
133
- score = 0.0
134
-
135
- log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
136
-
137
- observation = env.reset(task=task_name)
138
- final_done = False
139
- final_error: Optional[str] = None
140
-
141
- for step in range(1, MAX_STEPS + 1):
142
- if observation.done:
143
- break
144
-
145
- action = _model_action(client, observation, history)
146
- result = env.step(action)
147
- observation = result.observation
148
-
149
- reward = float(result.reward or 0.0)
150
- final_done = bool(result.done)
151
- final_error = observation.error_message
152
- steps_taken = step
153
- rewards.append(reward)
154
-
155
- action_text = _clean_text(_format_action(action))
156
- log_step(step=step, action=action_text, reward=reward, done=final_done, error=final_error)
157
-
158
- history.append(
159
- f"step={step} action={action_text} reward={_format_reward(reward)} done={str(final_done).lower()} "
160
- f"error={_format_error(final_error)} result={_clean_text(str(observation.last_action_result))}"
161
- )
162
-
163
- if final_done:
164
- break
165
-
166
- score = max(0.0, min(1.0, sum(rewards)))
167
- success = bool(final_done and final_error is None and score > 0.0)
168
- return success, steps_taken, score, rewards
169
-
170
-
171
- def main() -> None:
172
- client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
173
- env = _build_env()
174
-
175
- try:
176
- for task_name in TASKS:
177
- success, steps_taken, score, rewards = run_episode(client, env, task_name)
178
- _ = score
179
- log_end(success=success, steps=steps_taken, rewards=rewards)
180
- finally:
181
- env.close()
182
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  if __name__ == "__main__":
185
- main()
 
 
 
 
 
1
  """
2
+ AML Investigator - Baseline Inference Script
3
+ Loops through all 3 tasks to satisfy the Phase 2 Validator.
4
+ """
5
+ import asyncio
6
  import os
7
+ import json
8
+ import textwrap
9
+ import sys
10
+ from typing import List, Optional
11
  from openai import OpenAI
12
 
13
+ # Adjust the import based on your openenv server setup
14
+ # If running locally without docker wrapper for validation, you might need to import your Env directly
15
+ from server.AML_env_environment import AmlEnvironment
16
+ from models import AmlAction
 
 
 
 
 
17
 
18
 
19
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") or "http://127.0.0.1:1234"
 
22
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
23
  ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:7860")
24
  TASK_NAME = os.getenv("TASK_NAME", "aml_easy")
 
 
25
  TASKS = ["aml_easy", "aml_medium", "aml_hard"]
26
+ BENCHMARK = "aml_investigator"
27
+ MAX_STEPS = 25
28
+
29
+ SYSTEM_PROMPT = textwrap.dedent(
30
+ """
31
+ You are a Tier 1 AML Compliance Investigator.
32
+ You must investigate the provided alert by querying the bank's internal APIs.
33
+
34
+ You have a strict API budget. Be efficient.
35
+ Respond with EXACTLY ONE valid JSON object representing your action. Do not include markdown formatting or explanations.
36
+
37
+ Available Action JSON Schemas:
38
+ 1. {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}}
39
+ 2. {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}}
40
+ 3. {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}}
41
+ 4. {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}} (Use "CLEAR" for False Positives with empty evidence_links).
42
+ """
43
+ ).strip()
 
 
 
 
 
 
 
 
 
44
 
45
  def log_start(task: str, env: str, model: str) -> None:
46
  print(f"[START] task={task} env={env} model={model}", flush=True)
47
 
48
 
49
  def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
50
+ error_val = error if error else "null"
51
+ done_val = str(done).lower()
52
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
53
+
54
+ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
55
+ rewards_str = ",".join(f"{r:.2f}" for r in rewards)
56
+ print(f"[END] success={str(success).lower()} steps={steps} score={score:.2f} rewards={rewards_str}", flush=True)
57
+
58
+ def get_model_message(client: OpenAI, obs_dict: dict, history: List[str]) -> str:
59
+ history_block = "\n".join(history[-5:]) if history else "No previous steps."
60
+ user_prompt = f"Observation:\n{json.dumps(obs_dict, indent=2)}\n\nHistory:\n{history_block}\n\nProvide your next JSON action:"
61
+
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
  try:
63
  completion = client.chat.completions.create(
64
  model=MODEL_NAME,
 
66
  {"role": "system", "content": SYSTEM_PROMPT},
67
  {"role": "user", "content": user_prompt},
68
  ],
69
+ temperature=0.1,
70
+ max_tokens=200,
71
  )
72
+ return (completion.choices[0].message.content or "").strip()
73
+ except Exception as exc:
74
+ print(f"[DEBUG] Model request failed: {exc}", file=sys.stderr, flush=True)
75
+ # Fallback to prevent crash
76
+ return '{"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}}'
77
+
78
+ async def main() -> None:
79
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
80
+
81
+ # Initialize your environment natively for the baseline script
82
+ env = AmlEnvironment()
83
+
84
+ for task_name in TASKS:
85
+ history: List[str] = []
86
+ rewards: List[float] = []
87
+ steps_taken = 0
88
+ score = 0.0
89
+ success = False
90
+
91
+ log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
+ try:
94
+ obs = env.reset(task=task_name)
95
+
96
+ for step in range(1, MAX_STEPS + 1):
97
+ if obs.done:
98
+ break
99
+
100
+ obs_dict = obs.model_dump()
101
+ action_str = get_model_message(client, obs_dict, history)
102
+
103
+ # Parse LLM string to Pydantic Model
104
+ try:
105
+ # Strip possible markdown backticks
106
+ clean_str = action_str.replace("```json", "").replace("```", "").strip()
107
+ action_json = json.loads(clean_str)
108
+ action_obj = AmlAction.model_validate(action_json)
109
+ error = None
110
+ except Exception as e:
111
+ # Errors are data! If the LLM writes bad JSON or the payload fails schema validation, we catch it,
112
+ # record the error for the [STEP] log, and substitute a default CLEAR submit_decision so the step still runs.
113
+ error = f"JSON Parse/Schema Error: {str(e)}"
114
+ action_obj = AmlAction.model_validate(
115
+ {
116
+ "action": {
117
+ "action_type": "submit_decision",
118
+ "decision": "CLEAR",
119
+ "evidence_links": [],
120
+ }
121
+ }
122
+ )
123
+
124
+ obs = env.step(action_obj)
125
+
126
+ reward = obs.reward or 0.0
127
+ done = obs.done
128
+
129
+ rewards.append(reward)
130
+ steps_taken = step
131
+
132
+ log_step(step=step, action=action_str.replace('\n', ''), reward=reward, done=done, error=error)
133
+ history.append(f"Step {step}: Action: {action_str} -> Result: {obs.last_action_result} | Error: {obs.error_message}")
134
+
135
+ if done:
136
+ break
137
+
138
+ # Calculate a baseline score for the stdout logs (Graders handle real scoring)
139
+ score = sum(rewards) + 1.0 if "submit_decision" in (obs.last_action or "") else 0.0
140
+ score = min(max(score, 0.01), 0.99)
141
+ success = score > 0.5
142
+
143
+ finally:
144
+ log_end(success=success, steps=steps_taken, score=score, rewards=rewards)
145
 
146
  if __name__ == "__main__":
147
+ asyncio.run(main())