DataBoySu committed · Commit 7ae8bca · 1 Parent(s): 2a6078a

fix issues

Files changed (2)
  1. inference.py +163 -114
  2. pre-val.sh +6 -14
inference.py CHANGED
@@ -1,61 +1,109 @@
 """
-AML Investigator - Baseline Inference Script
-Loops through all 3 tasks to satisfy the Phase 2 Validator.
-"""
-import asyncio
-import os
 import json
-import textwrap
-from typing import List, Optional

 from openai import OpenAI

-# Adjust the import based on your openenv server setup
-from openenv.core.env_server.interfaces import Environment
-# If running locally without docker wrapper for validation, you might need to import your Env directly
-from server.AML_env_environment import AmlEnvironment
-from models import AmlAction

-API_KEY = os.getenv("HF_TOKEN") or "lm-studio"
-API_BASE_URL = os.getenv("API_BASE_URL") or "https://router.huggingface.co/v1" or "http://localhost:1234/v1"
-MODEL_NAME = os.getenv("MODEL_NAME") or "openai/gpt-oss-20b"

-# Must match openenv.yaml EXACTLY
-TASKS = ["aml_easy", "aml_medium", "aml_hard"]
-BENCHMARK = "aml_investigator"
-MAX_STEPS = 25
-
-SYSTEM_PROMPT = textwrap.dedent(
-    """
-    You are a Tier 1 AML Compliance Investigator.
-    You must investigate the provided alert by querying the bank's internal APIs.
-
-    You have a strict API budget. Be efficient.
-    Respond with EXACTLY ONE valid JSON object representing your action. Do not include markdown formatting or explanations.
-
-    Available Action JSON Schemas:
-    1. {"action": {"action_type": "query_transactions", "account_id": "ACC-XXXX", "limit": 10, "offset": 0}}
-    2. {"action": {"action_type": "search_transactions", "account_id": "ACC-XXXX", "keyword": "invoice"}}
-    3. {"action": {"action_type": "get_kyc_record", "entity_id": "ENT-XXXX"}}
-    4. {"action": {"action_type": "submit_decision", "decision": "FRAUD", "evidence_links": ["ACC-1234"]}} (Use "CLEAR" for False Positives with empty evidence_links).
-    """
-).strip()

 def log_start(task: str, env: str, model: str) -> None:
     print(f"[START] task={task} env={env} model={model}", flush=True)

 def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
-    error_val = error if error else "null"
-    done_val = str(done).lower()
-    print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
-
-def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> None:
-    rewards_str = ",".join(f"{r:.2f}" for r in rewards)
-    print(f"[END] success={str(success).lower()} steps={steps} score={score:.3f} rewards={rewards_str}", flush=True)
-
-def get_model_message(client: OpenAI, obs_dict: dict, history: List[str]) -> str:
-    history_block = "\n".join(history[-5:]) if history else "No previous steps."
-    user_prompt = f"Observation:\n{json.dumps(obs_dict, indent=2)}\n\nHistory:\n{history_block}\n\nProvide your next JSON action:"
-
     try:
         completion = client.chat.completions.create(
             model=MODEL_NAME,
@@ -63,74 +111,75 @@ def get_model_message(client: OpenAI, obs_dict: dict, history: List[str]) -> str
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": user_prompt},
             ],
-            temperature=0.1,
-            max_tokens=200,
         )
-        return (completion.choices[0].message.content or "").strip()
-    except Exception as exc:
-        print(f"[DEBUG] Model request failed: {exc}", flush=True)
-        # Fallback to prevent crash
-        return '{"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}}'
-
-async def main() -> None:
-    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
-
-    # Initialize your environment natively for the baseline script
-    env = AmlEnvironment()
-
-    for task_name in TASKS:
-        history: List[str] = []
-        rewards: List[float] = []
-        steps_taken = 0
-        score = 0.0
-        success = False
-
-        log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
-
         try:
-            obs = env.reset(task=task_name)
-
-            for step in range(1, MAX_STEPS + 1):
-                if obs.done:
-                    break
-
-                obs_dict = obs.model_dump()
-                action_str = get_model_message(client, obs_dict, history)
-
-                # Parse LLM string to Pydantic Model
-                try:
-                    # Strip possible markdown backticks
-                    clean_str = action_str.replace("```json", "").replace("```", "").strip()
-                    action_json = json.loads(clean_str)
-                    action_obj = AmlAction.model_validate(action_json)
-                    error = None
-                except Exception as e:
-                    # Errors are data! If the LLM writes bad JSON, we catch it and force a dummy action
-                    # so the environment can return a schema error to the LLM.
-                    error = f"JSON Parse/Schema Error: {str(e)}"
-                    action_obj = AmlAction(action={"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []})
-
-                obs = env.step(action_obj)
-
-                reward = obs.reward or 0.0
-                done = obs.done
-
-                rewards.append(reward)
-                steps_taken = step
-
-                log_step(step=step, action=action_str.replace('\n', ''), reward=reward, done=done, error=error)
-                history.append(f"Step {step}: Action: {action_str} -> Result: {obs.last_action_result} | Error: {obs.error_message}")
-
-                if done:
-                    break
-
-            # Calculate a baseline score for the stdout logs (Graders handle real scoring)
-            score = sum(rewards) + 1.0 if "submit_decision" in (obs.last_action or "") else 0.0
-            score = min(max(score, 0.0), 1.0)
-            success = score > 0.5
-
-        finally:
-            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

 if __name__ == "__main__":
-    asyncio.run(main())
 
+"""Baseline inference runner for AML_env.
+
+The script supports local LM Studio via an OpenAI-compatible base URL and keeps
+the multi-task loop expected by the project validator.
 """
+
+from __future__ import annotations
+
 import json
+import os
+import sys
+from pathlib import Path
+from typing import Any, Optional
+
 from openai import OpenAI

+try:
+    from AML_env.client import AmlEnv
+    from AML_env.models import AmlAction
+except Exception:
+    # Fall back to loose-module imports when the package layout is unavailable.
+    ROOT_DIR = Path(__file__).resolve().parent
+    if str(ROOT_DIR) not in sys.path:
+        sys.path.insert(0, str(ROOT_DIR))
+    from client import AmlEnv
+    from models import AmlAction
+
+
+API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")  # e.g. http://127.0.0.1:1234/v1 for LM Studio
+MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
+HF_TOKEN = os.getenv("HF_TOKEN", "lm-studio")  # local OpenAI-compatible servers accept any non-empty key
+LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
+ENV_BASE_URL = os.getenv("ENV_BASE_URL", "http://127.0.0.1:7860")
+TASK_NAME = os.getenv("TASK_NAME", "aml_easy")
+BENCHMARK = os.getenv("BENCHMARK", "AML_env")
+MAX_STEPS = int(os.getenv("MAX_STEPS", "25"))
+TASKS = ["aml_easy", "aml_medium", "aml_hard"]

+
+ SYSTEM_PROMPT = (
41
+ "You are a Tier 1 AML Compliance Investigator. "
42
+ "Return exactly one JSON object with the nested shape {\"action\": {...}}. "
43
+ "Allowed action types: query_transactions, search_transactions, get_kyc_record, submit_decision. "
44
+ "Do not output markdown, code fences, or explanations."
45
+ )
46
+
47
+
48
+ def _clean_text(value: str) -> str:
49
+ return value.replace("\n", " ").replace("\r", " ").strip()
50
+
51
+
52
+ def _format_reward(value: float) -> str:
53
+ return f"{value:.2f}"
54
+
55
+
56
+ def _format_action(action: AmlAction) -> str:
57
+ return json.dumps(action.model_dump(), separators=(",", ":"), ensure_ascii=True)
58
+
59
+
60
+ def _format_error(error: Optional[str]) -> str:
61
+ return error if error else "null"
62
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
 def log_start(task: str, env: str, model: str) -> None:
     print(f"[START] task={task} env={env} model={model}", flush=True)

+
 def log_step(step: int, action: str, reward: float, done: bool, error: Optional[str]) -> None:
+    print(
+        f"[STEP] step={step} action={action} reward={_format_reward(reward)} done={str(done).lower()} error={_format_error(error)}",
+        flush=True,
+    )
+
+
+def log_end(success: bool, steps: int, rewards: list[float]) -> None:
+    rewards_str = ",".join(_format_reward(r) for r in rewards)
+    print(f"[END] success={str(success).lower()} steps={steps} rewards={rewards_str}", flush=True)
+
+
+def _build_env() -> AmlEnv:
+    if LOCAL_IMAGE_NAME:
+        return AmlEnv.from_docker_image(LOCAL_IMAGE_NAME)
+    return AmlEnv(base_url=ENV_BASE_URL)
+
+
+def _fallback_action() -> AmlAction:
+    return AmlAction.model_validate(
+        {
+            "action": {
+                "action_type": "submit_decision",
+                "decision": "CLEAR",
+                "evidence_links": [],
+            }
+        }
+    )
+
+
+def _model_action(client: OpenAI, observation: Any, history: list[str]) -> AmlAction:
+    history_block = "\n".join(history[-5:]) if history else "No prior steps."
+    # Hoisted out of the f-string: reusing double quotes inside {...} is a
+    # SyntaxError on Python < 3.12.
+    obs_json = json.dumps(observation.model_dump(), separators=(",", ":"), ensure_ascii=True)
+    user_prompt = (
+        f"Alert:\n{observation.alert_details}\n\n"
+        f"Observation:\n{obs_json}\n\n"
+        f"History:\n{history_block}\n\n"
+        "Return the next JSON action."
+    )
+
     try:
         completion = client.chat.completions.create(
             model=MODEL_NAME,
             messages=[
                 {"role": "system", "content": SYSTEM_PROMPT},
                 {"role": "user", "content": user_prompt},
             ],
+            temperature=0.0,
+            max_tokens=256,
         )
+        content = (completion.choices[0].message.content or "").strip()
+        if not content:
+            return _fallback_action()
         try:
+            return AmlAction.model_validate_json(content)
+        except Exception:
+            return AmlAction.model_validate(json.loads(content))
+    except Exception:
+        return _fallback_action()
+
+
+def run_episode(client: OpenAI, env: AmlEnv, task_name: str) -> tuple[bool, int, float, list[float]]:
+    history: list[str] = []
+    rewards: list[float] = []
+    steps_taken = 0
+    success = False
+    score = 0.0
+
+    log_start(task=task_name, env=BENCHMARK, model=MODEL_NAME)
+
+    observation = env.reset(task=task_name)
+    final_done = False
+    final_error: Optional[str] = None
+
+    for step in range(1, MAX_STEPS + 1):
+        if observation.done:
+            break
+
+        action = _model_action(client, observation, history)
+        result = env.step(action)
+        observation = result.observation
+
+        reward = float(result.reward or 0.0)
+        final_done = bool(result.done)
+        final_error = observation.error_message
+        steps_taken = step
+        rewards.append(reward)
+
+        action_text = _clean_text(_format_action(action))
+        log_step(step=step, action=action_text, reward=reward, done=final_done, error=final_error)
+
+        history.append(
+            f"step={step} action={action_text} reward={_format_reward(reward)} done={str(final_done).lower()} "
+            f"error={_format_error(final_error)} result={_clean_text(str(observation.last_action_result))}"
+        )
+
+        if final_done:
+            break
+
+    score = max(0.0, min(1.0, sum(rewards)))
+    success = bool(final_done and final_error is None and score > 0.0)
+    return success, steps_taken, score, rewards
+
+
+def main() -> None:
+    client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
+    env = _build_env()
+
+    try:
+        for task_name in TASKS:
+            success, steps_taken, _score, rewards = run_episode(client, env, task_name)
+            log_end(success=success, steps=steps_taken, rewards=rewards)
+    finally:
+        env.close()
+

 if __name__ == "__main__":
+    main()
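
A minimal local run of the new script might look like the sketch below (not part of this commit; every value is illustrative — the script itself only reads the environment variables defined at the top of inference.py):

    # Hypothetical invocation; URLs and the image tag are examples, not defaults.
    export API_BASE_URL="http://127.0.0.1:1234/v1"   # any OpenAI-compatible server, e.g. LM Studio
    export MODEL_NAME="openai/gpt-oss-20b"
    export ENV_BASE_URL="http://127.0.0.1:7860"      # an already-running AML_env server
    # export LOCAL_IMAGE_NAME="aml-env:latest"       # or let _build_env() start a container instead
    python inference.py

Each task then produces the [START]/[STEP]/[END] lines emitted by the logging helpers above.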
pre-val.sh CHANGED
@@ -135,19 +135,17 @@ fi

 if [ -f "$REPO_DIR/Dockerfile" ]; then
     DOCKER_CONTEXT="$REPO_DIR"
-    DOCKERFILE_PATH="$REPO_DIR/Dockerfile"
 elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
-    DOCKER_CONTEXT="$REPO_DIR"
-    DOCKERFILE_PATH="$REPO_DIR/server/Dockerfile"
 else
     fail "No Dockerfile found in repo root or server/ directory"
     stop_at "Step 2"
 fi

-log "  Found Dockerfile at $DOCKERFILE_PATH"

 BUILD_OK=false
-BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build -f "$DOCKERFILE_PATH" "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

 if [ "$BUILD_OK" = true ]; then
     pass "Docker build succeeded"
@@ -159,20 +157,14 @@ fi

 log "${BOLD}Step 3/3: Running openenv validate${NC} ..."

-OPENENV_BIN=""
-if command -v openenv &>/dev/null; then
-    OPENENV_BIN="openenv"
-elif [ -x "$REPO_DIR/.venv/bin/openenv" ]; then
-    OPENENV_BIN="$REPO_DIR/.venv/bin/openenv"
-else
     fail "openenv command not found"
     hint "Install it: pip install openenv-core"
-    hint "Or create a local venv in the repo with .venv/bin/openenv available."
     stop_at "Step 3"
 fi

 VALIDATE_OK=false
-VALIDATE_OUTPUT=$(cd "$REPO_DIR" && "$OPENENV_BIN" validate 2>&1) && VALIDATE_OK=true

 if [ "$VALIDATE_OK" = true ]; then
     pass "openenv validate passed"
@@ -190,4 +182,4 @@ printf "${GREEN}${BOLD} Your submission is ready to submit.${NC}\n"
 printf "${BOLD}========================================${NC}\n"
 printf "\n"

-exit 0
 
 if [ -f "$REPO_DIR/Dockerfile" ]; then
     DOCKER_CONTEXT="$REPO_DIR"
 elif [ -f "$REPO_DIR/server/Dockerfile" ]; then
+    DOCKER_CONTEXT="$REPO_DIR/server"
 else
     fail "No Dockerfile found in repo root or server/ directory"
     stop_at "Step 2"
 fi

+log "  Found Dockerfile in $DOCKER_CONTEXT"

 BUILD_OK=false
+BUILD_OUTPUT=$(run_with_timeout "$DOCKER_BUILD_TIMEOUT" docker build "$DOCKER_CONTEXT" 2>&1) && BUILD_OK=true

 if [ "$BUILD_OK" = true ]; then
     pass "Docker build succeeded"

 log "${BOLD}Step 3/3: Running openenv validate${NC} ..."

+if ! command -v openenv &>/dev/null; then
     fail "openenv command not found"
     hint "Install it: pip install openenv-core"
     stop_at "Step 3"
 fi

 VALIDATE_OK=false
+VALIDATE_OUTPUT=$(cd "$REPO_DIR" && openenv validate 2>&1) && VALIDATE_OK=true

 if [ "$VALIDATE_OK" = true ]; then
     pass "openenv validate passed"

 printf "${BOLD}========================================${NC}\n"
 printf "\n"

+exit 0
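
For reference, the two build paths the revised check distinguishes reduce to the manual commands below (a sketch under the same repo layout; run_with_timeout and the log/pass/fail helpers are pre-val.sh internals and omitted here):

    # Dockerfile at the repo root: build with the repo as context
    docker build "$REPO_DIR"
    # Dockerfile under server/: build with server/ itself as the context,
    # matching DOCKER_CONTEXT="$REPO_DIR/server" above
    docker build "$REPO_DIR/server"

Dropping the explicit -f flag works because docker build defaults to the Dockerfile at the root of the given build context.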