# molforge / inference.py
# Adhitya122's picture
# Prepare MolForge OpenEnv Docker Space submission
# bf9e424 verified
"""Judge-facing baseline inference script for MolForge."""
from __future__ import annotations
import json
import os
from typing import Any, Optional, cast
from openai import OpenAI
from inference_common import (
COMPACT_SYSTEM_PROMPT,
SYSTEM_PROMPT,
build_model_payload,
extract_json,
)
try:
from molforge.models import MolForgeAction, MolForgeObservation
from molforge.server.molforge_environment import MolForgeEnvironment
except ImportError:
from models import MolForgeAction, MolForgeObservation
from server.molforge_environment import MolForgeEnvironment
# Remote model endpoint, model identifier and credential, all read from the
# environment. HF_TOKEN is used as the credential when API_KEY is unset/empty.
API_BASE_URL = os.environ.get("API_BASE_URL")
MODEL_NAME = os.environ.get("MODEL_NAME")
API_KEY = os.environ.get("API_KEY") or os.environ.get("HF_TOKEN")

# Upper bound on model turns attempted within a single episode.
MAX_TURNS = 10

# Per-request timeouts (passed to the OpenAI client) and the completion token
# budget; each can be overridden through an environment variable of the same name.
MODEL_TIMEOUT_S = float(os.environ.get("MODEL_TIMEOUT_S", "35"))
MODEL_LONG_TIMEOUT_S = float(os.environ.get("MODEL_LONG_TIMEOUT_S", "45"))
MODEL_RETRY_TIMEOUT_S = float(os.environ.get("MODEL_RETRY_TIMEOUT_S", "15"))
MODEL_MAX_TOKENS = int(os.environ.get("MODEL_MAX_TOKENS", "220"))

# Bounds used when clamping scores so reported values never reach 0 or 1.
MIN_REPORTED_SCORE = 1e-6
MAX_REPORTED_SCORE = 1.0 - MIN_REPORTED_SCORE
def main() -> None:
    """Drive three MolForge episodes with the remote model and log the results.

    For every episode this prints one ``[START]`` line, one ``[STEP]`` line per
    environment step (or a single failed ``[STEP]`` line when the model or the
    environment raises), and one ``[END]`` line carrying the terminal grader
    scores. After the last episode a single ``[SUMMARY]`` JSON line with the
    per-episode and averaged scores is printed.

    Raises:
        RuntimeError: If ``API_BASE_URL``, ``MODEL_NAME``, or the API key
            (``API_KEY`` / ``HF_TOKEN``) is not configured.
    """

    def _mean(values):
        # Plain arithmetic mean of a non-empty list of floats.
        return sum(values) / len(values)

    env = MolForgeEnvironment()
    if not (API_BASE_URL and MODEL_NAME and API_KEY):
        raise RuntimeError(
            "API_BASE_URL, MODEL_NAME, and API_KEY or HF_TOKEN are required. "
            "No heuristic fallback is available."
        )
    client = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)

    scores, raw_final_scores = [], []
    submission_scores, progress_scores = [], []
    model_action_count = 0

    for episode_number in range(1, 4):
        observation = env.reset()
        task_name = observation.scenario_id
        episode_error = ""
        print(
            f"[START] task={task_name} difficulty={observation.difficulty} episode={episode_number}",
            flush=True,
        )

        # Play at most MAX_TURNS turns, stopping early once the episode is done.
        turns_left = MAX_TURNS
        while turns_left > 0 and not observation.done:
            turns_left -= 1
            try:
                action = choose_action(client, observation)
                model_action_count += 1
                observation = env.step(action)
            except Exception as exc:
                # A failure here ends the episode; the message is carried
                # over to the [END] line below.
                episode_error = f"{exc.__class__.__name__}:{exc}"
                print(
                    f"[STEP] task={task_name} step={observation.step_index + 1} "
                    f"reward=0.000000 action=model_error status=failed",
                    flush=True,
                )
                break
            print(
                f"[STEP] task={task_name} step={observation.step_index} "
                f"reward={observation.reward:.6f} action={action.action_type} "
                f"actor={action.acting_role} status={observation.governance.status}",
                flush=True,
            )

        # Read the terminal grader scores; final_score falls back to
        # submission_score, and every missing entry defaults to 0.0.
        grader_scores = observation.metadata.get("terminal_grader_scores", {})
        submission_fallback = grader_scores.get("submission_score", 0.0)
        raw_final_score = float(grader_scores.get("final_score", submission_fallback))
        final_score = reportable_score(raw_final_score)
        submission_score = float(submission_fallback)
        progress_score = float(grader_scores.get("progress_score", 0.0))

        scores.append(final_score)
        raw_final_scores.append(raw_final_score)
        submission_scores.append(submission_score)
        progress_scores.append(progress_score)

        end_line = (
            f"[END] task={task_name} score={final_score:.6f} raw_score={raw_final_score:.6f} "
            f"submission_score={submission_score:.6f} progress_score={progress_score:.6f} "
            f"steps={observation.step_index}"
        )
        if episode_error:
            end_line = end_line + f" error={json.dumps(episode_error)}"
        print(end_line, flush=True)

        report_card = observation.report_card
        if report_card:
            print(report_card, flush=True)

    average = _mean(scores)
    average_progress = _mean(progress_scores)
    summary = {
        "scores": scores,
        "raw_final_scores": raw_final_scores,
        "average_final_score": round(reportable_score(average), 6),
        "submission_scores": submission_scores,
        "average_submission_score": round(_mean(submission_scores), 4),
        "progress_scores": progress_scores,
        "average_progress_score": round(average_progress, 4),
        "model_action_count": model_action_count,
        "model_name": MODEL_NAME,
        "api_base_url": API_BASE_URL,
        "fallback_enabled": False,
    }
    print("[SUMMARY] " + json.dumps(summary, separators=(",", ":")), flush=True)
def reportable_score(score: float) -> float:
    """Clamp a score into the range that is safe to report to the validator.

    Validator-facing scores must be strictly between 0 and 1. The result is
    therefore clamped to ``[MIN_REPORTED_SCORE, MAX_REPORTED_SCORE]`` rather
    than only remapping values at or beyond 0 and 1: a value such as ``1e-9``
    is technically positive but prints as ``0.000000`` once formatted with six
    decimals, which is how the ``[END]`` line renders it.

    Args:
        score: Raw score; may be out of range or NaN.

    Returns:
        ``score`` unchanged when it already lies inside the reportable range,
        otherwise the nearest bound. NaN maps to ``MIN_REPORTED_SCORE``.
    """
    # NaN fails every ordered comparison, so min()/max() would pass it through.
    if score != score:
        return MIN_REPORTED_SCORE
    return min(max(score, MIN_REPORTED_SCORE), MAX_REPORTED_SCORE)
def choose_action(client: OpenAI, observation: MolForgeObservation) -> MolForgeAction:
    """Return the model's action for this observation, or raise if there is none.

    Raises:
        RuntimeError: If ``ask_model`` returns no action; the message carries
            the error text it reported.
    """
    proposed, failure = ask_model(client, observation)
    if proposed is not None:
        return proposed
    raise RuntimeError(f"Model action failed: {failure}")
def ask_model(client: OpenAI, observation: MolForgeObservation) -> tuple[Optional[MolForgeAction], str]:
    """Ask the model for an action, retrying once with the compact prompt.

    The first attempt uses the full payload and system prompt with a timeout
    chosen by ``model_timeout_for_step``. If anything in that attempt raises,
    a second attempt uses the compact payload and prompt with
    ``MODEL_RETRY_TIMEOUT_S``.

    Returns:
        ``(action, "")`` on the first attempt that succeeds, otherwise
        ``(None, errors)`` where ``errors`` joins the per-attempt failure
        descriptions with ``" | "``.
    """
    # (error label, system prompt, use compact payload?) for each attempt, in order.
    attempts = (
        ("full_prompt", SYSTEM_PROMPT, False),
        ("compact_prompt", COMPACT_SYSTEM_PROMPT, True),
    )
    failures = []
    for label, system_prompt, compact in attempts:
        try:
            payload = build_model_payload(observation, compact=compact)
            if compact:
                timeout_s = MODEL_RETRY_TIMEOUT_S
            else:
                timeout_s = model_timeout_for_step(observation)
            data = request_action_json(
                client=client,
                system_prompt=system_prompt,
                user_payload=payload,
                timeout_s=timeout_s,
            )
            return MolForgeAction(**data), ""
        except Exception as exc:
            # Record the failure and fall through to the next attempt, if any.
            failures.append(f"{label}:{exc.__class__.__name__}:{exc}")
    return None, " | ".join(failures)
def request_action_json(
    *,
    client: OpenAI,
    system_prompt: str,
    user_payload: dict[str, Any],
    timeout_s: float,
) -> dict[str, Any]:
    """Send one chat-completion request and return the JSON parsed from the reply.

    Args:
        client: OpenAI client; a copy configured with ``timeout_s`` is used.
        system_prompt: Content of the system message.
        user_payload: Serialized with ``json.dumps(..., indent=2)`` as the
            user message.
        timeout_s: Request timeout applied via ``client.with_options``.

    Returns:
        Whatever ``extract_json`` produces from the first choice's text.
    """
    timed_client = client.with_options(timeout=timeout_s)
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": json.dumps(user_payload, indent=2)},
    ]
    completion = timed_client.chat.completions.create(
        model=MODEL_NAME,
        temperature=0.0,
        max_tokens=MODEL_MAX_TOKENS,
        messages=messages,
    )
    content = completion.choices[0].message.content
    if not isinstance(content, list):
        # Plain string reply; None is treated as empty text.
        return extract_json(content or "")
    # List-shaped content: concatenate the "text" field of each part.
    parts = cast(list[dict[str, Any]], content)
    return extract_json("".join(piece.get("text", "") for piece in parts))
def model_timeout_for_step(observation: MolForgeObservation) -> float:
    """Pick the request timeout for the current step.

    Returns ``MODEL_LONG_TIMEOUT_S`` when the scenario difficulty is ``"hard"``
    or when ``step_index`` is within two of ``max_steps``; otherwise returns
    ``MODEL_TIMEOUT_S``.
    """
    is_hard = observation.difficulty == "hard"
    if is_hard or observation.step_index >= observation.max_steps - 2:
        return MODEL_LONG_TIMEOUT_S
    return MODEL_TIMEOUT_S
# Script entry point: run the baseline only when this file is executed directly.
if __name__ == "__main__":
    main()