parth-1 committed on
Commit
929006e
·
unverified ·
1 Parent(s): ac23714

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +55 -35
inference.py CHANGED
@@ -3,7 +3,7 @@ import json
3
  import requests
4
  from openai import OpenAI
5
 
6
- # 1. 🚨 MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
7
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
8
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
9
  MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3-70b-chat-hf")
@@ -22,9 +22,23 @@ TASKS = [
22
  "task_4_targeting"
23
  ]
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  def get_llm_action(observation_data):
26
  """Asks the LLM what action to take based on the ad observation."""
27
-
28
  system_prompt = """You are an expert Meta Ad-Policy Moderator AI.
29
  Evaluate the ad and output a decision. Using tools costs -0.05 points, so be efficient.
30
 
@@ -57,49 +71,55 @@ def get_llm_action(observation_data):
57
  "reasoning": result.get("reasoning", "Fallback reasoning")
58
  }
59
  except Exception as e:
60
- print(f"⚠️ LLM Call Failed: {e}. Defaulting to safe fallback.")
61
  return {"action_type": "analyze_image", "reasoning": "Error recovery."}
62
 
63
  def main() -> None:
64
- print("🚀 Starting Meta Ad-Policy Automated Inference...")
65
- total_score = 0.0
66
-
67
  for task_id in TASKS:
68
- print(f"\n--- 🎬 Starting Episode: {task_id} ---")
 
 
 
 
69
 
70
  try:
71
  res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
72
  if res.status_code != 200:
73
- print(f" Env connection failed. Check if Docker is running on port 8000.")
74
- return
75
- except requests.exceptions.ConnectionError:
76
- print(f"❌ Env connection refused. Is your OpenEnv Docker container running?")
77
- return
78
-
79
- observation = res.json()
80
- done = False
81
- step_count = 0
82
-
83
- while not done and step_count < MAX_STEPS:
84
- step_count += 1
85
- print(f" Step {step_count} | Status: {observation.get('status_message', 'No status')}")
86
-
87
- action_payload = get_llm_action(observation)
88
- print(f" 🤖 Agent Action: {action_payload['action_type'].upper()}")
89
-
90
- step_res = requests.post(f"{ENV_URL}/step", json=action_payload)
91
- step_data = step_res.json()
92
 
93
- # Extract from the OpenEnv schema
94
- observation = step_data.get("observation", step_data)
95
- done = observation.get("done", False)
96
- reward = observation.get("reward", 0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
 
98
- if done:
99
- print(f" ✅ Episode Finished! Final Step Reward: {reward}")
100
- total_score += reward
101
 
102
- print(f"\n🎉 Evaluation Complete! Total Agent Score: {total_score} / {len(TASKS)}")
 
 
103
 
104
  if __name__ == "__main__":
105
- main()
 
3
  import requests
4
  from openai import OpenAI
5
 
6
+ # 1. MANDATORY VARIABLES EXACTLY AS REQUESTED BY SCALAR
7
  API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
8
  API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY", "dummy_local_token")
9
  MODEL_NAME = os.getenv("MODEL_NAME", "meta-llama/Llama-3-70b-chat-hf")
 
22
  "task_4_targeting"
23
  ]
24
 
25
+ # --- STRICT GRADING LOGGERS ---
26
def log_start(task: str, env: str, model: str) -> None:
    """Print the machine-parseable [START] marker announcing a grading episode."""
    banner = f"[START] task={task} env={env} model={model}"
    print(banner, flush=True)
28
+
29
+ def log_step(step: int, action: str, reward: float, done: bool, error: str = None) -> None:
30
+ error_val = error if error else "null"
31
+ done_val = str(done).lower()
32
+ print(f"[STEP] step={step} action={action} reward={reward:.2f} done={done_val} error={error_val}", flush=True)
33
+
34
def log_end(success: bool, steps: int, score: float, rewards: list) -> None:
    """Print the machine-parseable [END] summary: outcome, steps, score, rewards."""
    formatted_rewards = ",".join(f"{value:.2f}" for value in rewards)
    outcome = str(success).lower()
    print(
        f"[END] success={outcome} steps={steps} score={score:.3f} rewards={formatted_rewards}",
        flush=True,
    )
38
+ # ------------------------------
39
+
40
  def get_llm_action(observation_data):
41
  """Asks the LLM what action to take based on the ad observation."""
 
42
  system_prompt = """You are an expert Meta Ad-Policy Moderator AI.
43
  Evaluate the ad and output a decision. Using tools costs -0.05 points, so be efficient.
44
 
 
71
  "reasoning": result.get("reasoning", "Fallback reasoning")
72
  }
73
  except Exception as e:
 
74
  return {"action_type": "analyze_image", "reasoning": "Error recovery."}
75
 
76
def main() -> None:
    """Run every grading task against the HTTP environment and emit strict logs.

    For each task id in TASKS: announces the episode via log_start, resets the
    environment over HTTP, then steps an LLM-chosen action loop until the
    environment reports done or MAX_STEPS is reached, logging each step and a
    final [END] summary.

    NOTE(review): this body was reconstructed from a flattened diff view, so
    the statement indentation (in particular, that the score computation sits
    AFTER the while loop rather than inside it) is inferred from the comments
    and logging semantics — confirm against the committed file.
    """
    for task_id in TASKS:
        log_start(task=task_id, env="meta_ad_policy_sandbox", model=MODEL_NAME)

        # Per-episode accumulators; rewards keeps one entry per executed step.
        rewards = []
        steps_taken = 0
        success = False

        try:
            # Reset the environment for this task; a non-200 response aborts
            # the episode but the run continues with the next task.
            res = requests.post(f"{ENV_URL}/reset", json={"task_id": task_id})
            if res.status_code != 200:
                log_step(step=1, action="reset_failed", reward=0.0, done=True, error=f"HTTP {res.status_code}")
                log_end(success=False, steps=0, score=0.0, rewards=[])
                continue

            # The reset payload may wrap the observation or be the observation
            # itself — presumably depends on the OpenEnv schema version; verify.
            data = res.json()
            observation = data.get("observation", data)
            done = False

            while not done and steps_taken < MAX_STEPS:
                steps_taken += 1

                # Ask the LLM for the next action given the current observation.
                action_payload = get_llm_action(observation)
                action_str = action_payload["action_type"]

                # Execute the chosen action against the environment.
                step_res = requests.post(f"{ENV_URL}/step", json=action_payload)
                step_data = step_res.json()

                # Parse the step response; observation/done/reward are read
                # from the top level here (the reset path reads a wrapped
                # observation) — asymmetry to confirm against the env schema.
                observation = step_data.get("observation", {})
                done = step_data.get("done", False)
                reward = step_data.get("reward", 0.0)

                rewards.append(reward)
                log_step(step=steps_taken, action=action_str, reward=reward, done=done, error=None)

            # Final score is the reward sum clamped to [0, 1]; any positive
            # score counts as success.
            score = min(max(sum(rewards), 0.0), 1.0)
            success = score > 0.0

            log_end(success=success, steps=steps_taken, score=score, rewards=rewards)

        except Exception as e:
            # Broad catch is deliberate: any failure (connection refused, bad
            # JSON, missing keys) is logged as an episode-ending error step so
            # the strict-log grader always sees a terminated episode.
            log_step(step=steps_taken+1, action="exception", reward=0.0, done=True, error=str(e).replace("\n", " "))
            log_end(success=False, steps=steps_taken, score=0.0, rewards=rewards)


if __name__ == "__main__":
    main()