Omkar1806 commited on
Commit
5e1996c
·
verified ·
1 Parent(s): 5881e78

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +42 -10
inference.py CHANGED
@@ -7,9 +7,10 @@ from env import EmailTriageEnv
7
  from app import smart_agent_logic
8
 
9
 
10
- API_BASE_URL = os.getenv("API_BASE_URL", "https://router.huggingface.co/v1")
11
- MODEL_NAME = os.getenv("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
12
- API_KEY = os.getenv("HF_TOKEN") or os.getenv("API_KEY") or os.getenv("OPENAI_API_KEY")
 
13
 
14
  TASK_NAME = os.getenv("MY_ENV_V4_TASK", "easy")
15
  BENCHMARK = "email_triage_env"
@@ -42,15 +43,12 @@ def log_end(success: bool, steps: int, score: float, rewards: List[float]) -> No
42
 
43
 
44
  def main():
45
- # ✅ SAFE OpenAI initialization (FIX)
46
- client = None
47
  try:
48
- if API_KEY:
49
- client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
50
- else:
51
- print("[DEBUG] No API key found, running without OpenAI client", flush=True)
52
  except Exception as e:
53
  print(f"[DEBUG] OpenAI init failed: {e}", flush=True)
 
54
 
55
  env = EmailTriageEnv(task=TASK_NAME)
56
 
@@ -71,7 +69,41 @@ def main():
71
  try:
72
  desc = state["description"]
73
 
74
- action_list = smart_agent_logic(desc)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
75
 
76
  state, reward, done, _, _ = env.step(action_list)
77
 
 
7
  from app import smart_agent_logic
8
 
9
 
10
+ # MUST use these EXACT env vars
11
+ API_BASE_URL = os.environ.get("API_BASE_URL")
12
+ API_KEY = os.environ.get("API_KEY")
13
+ MODEL_NAME = os.environ.get("MODEL_NAME", "Qwen/Qwen2.5-72B-Instruct")
14
 
15
  TASK_NAME = os.getenv("MY_ENV_V4_TASK", "easy")
16
  BENCHMARK = "email_triage_env"
 
43
 
44
 
45
  def main():
46
+ # ✅ REQUIRED: Initialize OpenAI client with provided proxy
 
47
  try:
48
+ client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
 
 
49
  except Exception as e:
50
  print(f"[DEBUG] OpenAI init failed: {e}", flush=True)
51
+ client = None
52
 
53
  env = EmailTriageEnv(task=TASK_NAME)
54
 
 
69
  try:
70
  desc = state["description"]
71
 
72
+ # 🔥 LLM CALL (MANDATORY)
73
+ action_list = None
74
+
75
+ if client:
76
+ try:
77
+ response = client.chat.completions.create(
78
+ model=MODEL_NAME,
79
+ messages=[
80
+ {
81
+ "role": "system",
82
+ "content": "Classify email into 3 integers: urgency (0-2), routing (0-2), resolution (0-2). Return only numbers like: 2 1 2"
83
+ },
84
+ {
85
+ "role": "user",
86
+ "content": desc
87
+ }
88
+ ],
89
+ max_tokens=20,
90
+ temperature=0,
91
+ )
92
+
93
+ text = response.choices[0].message.content.strip()
94
+
95
+ # Parse response
96
+ action_list = [int(x) for x in text.replace(",", " ").split()[:3]]
97
+
98
+ if len(action_list) != 3:
99
+ raise ValueError("Invalid LLM output")
100
+
101
+ except Exception as llm_error:
102
+ print(f"[DEBUG] LLM failed: {llm_error}", flush=True)
103
+
104
+ # ✅ fallback if LLM fails
105
+ if not action_list:
106
+ action_list = smart_agent_logic(desc)
107
 
108
  state, reward, done, _, _ = env.step(action_list)
109