ChaitanyaRasane commited on
Commit
c5b1020
·
1 Parent(s): 7e8d400

feat: add inference.py for OpenEnv compliance

Browse files
Files changed (1) hide show
  1. inference.py +55 -0
inference.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ from env import UIEnv, Observation, Action
4
+ from baseline import agent_policy
5
+ from openai import OpenAI
6
+
7
+ def run_inference(task_id="easy"):
8
+ """
9
+ Standard OpenEnv inference entry point.
10
+ """
11
+ # 1. Setup Environment
12
+ env = UIEnv(seed=42, task=task_id)
13
+ obs = env.reset()
14
+
15
+ # 2. Setup Client (OpenAI or HF Router)
16
+ openai_key = os.getenv("OPENAI_API_KEY")
17
+ hf_token = os.getenv("HF_TOKEN")
18
+
19
+ if openai_key:
20
+ client = OpenAI(api_key=openai_key)
21
+ model_name = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
22
+ elif hf_token:
23
+ client = OpenAI(
24
+ base_url="https://router.huggingface.co/v1",
25
+ api_key=hf_token
26
+ )
27
+ model_name = "katanemo/Arch-Router-1.5B"
28
+ else:
29
+ # Fallback to no-client (heuristic only)
30
+ client = None
31
+ model_name = None
32
+
33
+ # 3. Perform Inference Step
34
+ if client:
35
+ action = agent_policy(client, obs, model_name)
36
+ else:
37
+ # Fallback to heuristic if no API key is provided
38
+ from baseline import heuristic_policy
39
+ action = heuristic_policy(obs)
40
+
41
+ # 4. Step Environment
42
+ new_obs, reward, done, info = env.step(action)
43
+
44
+ result = {
45
+ "action": action.type,
46
+ "reward": reward,
47
+ "done": done,
48
+ "outcome": info.get("outcome")
49
+ }
50
+
51
+ print(json.dumps(result, indent=2))
52
+ return result
53
+
54
+ if __name__ == "__main__":
55
+ run_inference()