DataBoySu committed on
Commit
a4c032a
·
1 Parent(s): bfed85f

pre thought

Browse files
Files changed (2) hide show
  1. inference.py +96 -6
  2. models.py +2 -1
inference.py CHANGED
@@ -16,7 +16,7 @@ from server.AML_env_environment import AmlEnvironment
16
  from models import AmlAction
17
 
18
 
19
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1") or "http://127.0.0.1:1234"
20
  MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
21
  HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
22
  LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
@@ -42,6 +42,70 @@ SYSTEM_PROMPT = textwrap.dedent(
42
  """
43
  ).strip()
44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
  def log_start(task: str, env: str, model: str) -> None:
46
  print(f"[START] task={task} env={env} model={model}", flush=True)
47
 
@@ -69,11 +133,37 @@ def get_model_message(client: OpenAI, obs_dict: dict, history: List[str]) -> str
69
  temperature=0.1,
70
  max_tokens=200,
71
  )
72
- return (completion.choices[0].message.content or "").strip()
73
- except Exception as exc:
74
- print(f"[DEBUG] Model request failed: {exc}", file=sys.stderr, flush=True)
75
- # Fallback to prevent crash
76
- return '{"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}}'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
  async def main() -> None:
79
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
 
16
  from models import AmlAction
17
 
18
 
19
# OpenAI-compatible endpoint; falls back to a local LM Studio server when the
# env var is unset or empty (the `or` also catches empty-string values).
API_BASE_URL = os.getenv("API_BASE_URL") or "http://127.0.0.1:1234/v1"
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-20b")
# "lm-studio" is a placeholder API key; local servers accept any token.
HF_TOKEN = os.getenv("HF_TOKEN") or "lm-studio"
# Optional; None when not provided — downstream code must handle that.
LOCAL_IMAGE_NAME = os.getenv("LOCAL_IMAGE_NAME")
 
42
  """
43
  ).strip()
44
 
45
# Last-resort action returned when every model request strategy fails:
# a CLEAR decision with no evidence, so the episode ends instead of crashing.
FALLBACK_ACTION_JSON = '{"action": {"action_type": "submit_decision", "decision": "CLEAR", "evidence_links": []}}'
46
+
47
+
48
+ def _extract_text_from_chat_completion(completion: object) -> str:
49
+ choices = getattr(completion, "choices", None) or []
50
+ if not choices:
51
+ raise ValueError("Model response has no choices")
52
+
53
+ first_choice = choices[0]
54
+ message = getattr(first_choice, "message", None)
55
+ if message is None:
56
+ raise ValueError("Model response choice has no message")
57
+
58
+ content = getattr(message, "content", None)
59
+ if isinstance(content, str) and content.strip():
60
+ return content.strip()
61
+
62
+ if isinstance(content, list):
63
+ chunks: List[str] = []
64
+ for item in content:
65
+ if isinstance(item, dict):
66
+ text_val = item.get("text")
67
+ if isinstance(text_val, str):
68
+ chunks.append(text_val)
69
+ merged = "".join(chunks).strip()
70
+ if merged:
71
+ return merged
72
+
73
+ raise ValueError("Model response content is empty")
74
+
75
+
76
+ def _extract_text_from_responses_api(response: object) -> str:
77
+ output_text = getattr(response, "output_text", None)
78
+ if isinstance(output_text, str) and output_text.strip():
79
+ return output_text.strip()
80
+
81
+ output = getattr(response, "output", None) or []
82
+ chunks: List[str] = []
83
+ for item in output:
84
+ content = getattr(item, "content", None) or []
85
+ for part in content:
86
+ text_val = getattr(part, "text", None)
87
+ if isinstance(text_val, str):
88
+ chunks.append(text_val)
89
+
90
+ merged = "".join(chunks).strip()
91
+ if merged:
92
+ return merged
93
+
94
+ raise ValueError("Responses API output is empty")
95
+
96
+
97
+ def _extract_text_from_completions_api(completion: object) -> str:
98
+ choices = getattr(completion, "choices", None) or []
99
+ if not choices:
100
+ raise ValueError("Completions API response has no choices")
101
+
102
+ first_choice = choices[0]
103
+ text_val = getattr(first_choice, "text", None)
104
+ if isinstance(text_val, str) and text_val.strip():
105
+ return text_val.strip()
106
+
107
+ raise ValueError("Completions API response text is empty")
108
+
109
def log_start(task: str, env: str, model: str) -> None:
    """Print a start-of-run marker to stdout, flushed immediately."""
    marker = f"[START] task={task} env={env} model={model}"
    print(marker, flush=True)
111
 
 
133
  temperature=0.1,
134
  max_tokens=200,
135
  )
136
+ return _extract_text_from_chat_completion(completion)
137
+ except Exception as chat_exc:
138
+ # Retry via Responses API for OpenAI-compatible providers that do not
139
+ # populate chat.completions choices consistently.
140
+ try:
141
+ response = client.responses.create(
142
+ model=MODEL_NAME,
143
+ instructions=SYSTEM_PROMPT,
144
+ input=user_prompt,
145
+ max_output_tokens=200,
146
+ )
147
+ return _extract_text_from_responses_api(response)
148
+ except Exception as responses_exc:
149
+ try:
150
+ completion = client.completions.create(
151
+ model=MODEL_NAME,
152
+ prompt=f"{SYSTEM_PROMPT}\n\n{user_prompt}",
153
+ temperature=0.1,
154
+ max_tokens=200,
155
+ )
156
+ return _extract_text_from_completions_api(completion)
157
+ except Exception as completions_exc:
158
+ print(
159
+ (
160
+ "[DEBUG] Model request failed: "
161
+ f"chat={chat_exc}; responses={responses_exc}; completions={completions_exc}"
162
+ ),
163
+ file=sys.stderr,
164
+ flush=True,
165
+ )
166
+ return FALLBACK_ACTION_JSON
167
 
168
  async def main() -> None:
169
  client = OpenAI(base_url=API_BASE_URL, api_key=HF_TOKEN)
models.py CHANGED
@@ -51,4 +51,5 @@ class SubmitDecision(Action):
51
  class AmlAction(Action):
52
  action: Union[QueryTransactions, SearchTransactions, GetKYCRecord, SubmitDecision] = Field(
53
  discriminator='action_type'
54
- )
 
 
51
class AmlAction(Action):
    """Wrapper holding one concrete AML action as a discriminated union.

    The ``discriminator='action_type'`` tells pydantic to dispatch parsing
    on each payload's ``action_type`` field to the matching action model.
    """
    action: Union[QueryTransactions, SearchTransactions, GetKYCRecord, SubmitDecision] = Field(
        discriminator='action_type'
    )
55
+