# deepbattler / RL / infer_battleground_cloud.py
# lbtwyk
# Update README to focus on RL training pipeline
# fed1ca7
#!/usr/bin/env python
# infer_battleground_cloud.py
#
# Cloud-based inference script for a fine-tuned Battlegrounds Qwen model hosted on Hugging Face.
#
# Backends supported:
# 1. Hugging Face Space exposing /generate_actions (preferred for this project)
# 2. Hugging Face Inference Endpoint / Hosted model via InferenceClient
#
# Usage examples:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --model-id iteratehack/deepbattler-battleground-gamehistory
#
# Or, if you deploy a dedicated Inference Endpoint:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --endpoint https://<your-endpoint>.inference.huggingface.cloud
#
# The script expects the same "state" structure and action JSON schema as
# train_battleground_rlaif_gamehistory.py.
import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional
import requests
from huggingface_hub import InferenceClient
from RL.battleground_nl_utils import game_state_to_natural_language
# Prompt header used by build_prompt when input_mode is not "nl": build_prompt
# appends a newline and the compact JSON-serialised state after this text.
# NOTE(review): the continuation lines of rules 3, 5 and 6 carry no leading
# indentation here -- confirm this matches the prompt used at training time,
# since any difference changes the text the model sees.
INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the game state JSON:
"""
# Prompt header used by build_prompt when input_mode == "nl": build_prompt
# appends a newline and the text returned by game_state_to_natural_language.
# The rules are the same as in INSTRUCTION_PREFIX; only the sentences that
# describe the input (JSON object vs. natural-language description) differ.
INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
"HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
otherwise null.
Now here is the description of the game state:
"""
def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Render one flattened game_history row into the prompt text for the model.

    The layout is intended to mirror ``_build_prompt`` in
    train_battleground_rlaif_gamehistory.py so that prompts seen at inference
    time look like the ones used for training.

    Args:
        example: Row with optional ``phase``, ``turn`` and ``state`` keys. A
            missing or falsy ``state`` is treated as an empty dict.
        input_mode: ``"nl"`` describes the state in natural language; any other
            value serialises it as compact JSON.

    Returns:
        The instruction prefix, a newline, and the rendered state.
    """
    state = example.get("state", {}) or {}

    if input_mode == "nl":
        description = game_state_to_natural_language(state)
        return INSTRUCTION_PREFIX_NL + "\n" + description

    # JSON mode: row-level phase/turn win; otherwise fall back to the values
    # nested under state["game_state"], then to fixed defaults.
    game_state = state.get("game_state", {}) or {}
    payload = {
        "task": "battlegrounds_policy_v1",
        "phase": example.get("phase", game_state.get("phase", "PlayerTurn")),
        "turn": example.get("turn", game_state.get("turn_number", 0)),
        "state": state,
    }
    serialized = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
    return INSTRUCTION_PREFIX + "\n" + serialized
def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Parse a model completion into a list of atomic action dicts.

    Accepted shapes of the decoded JSON object:
      - {"actions": [ {...}, {...}, ... ]}
      - {"actions": {...}}                  # single step, wrapped in a list
      - {"action": [ ... ]} / {"action": {...}}   # tolerated fallback key,
        consulted only when "actions" is absent

    The JSON object is decoded starting at the first ``{`` in the text. Any
    text before that brace or after the object's closing brace is ignored, so
    a completion such as ``{"actions": [...]} Hope {this} helps`` still parses
    even though the trailing text contains braces.

    Args:
        text: Raw completion text returned by the model.

    Returns:
        A new list of action dicts (possibly empty), or None when the input is
        not a string, no JSON object can be decoded, the expected key is
        missing or of the wrong type, or any step is not a dict.
    """
    if not isinstance(text, str):
        return None

    text = text.strip()
    start_idx = text.find("{")
    if start_idx == -1:
        return None

    try:
        # raw_decode stops at the end of the first complete JSON value instead
        # of requiring the whole remaining string to be valid JSON.
        obj, _end = json.JSONDecoder().raw_decode(text, start_idx)
    except (ValueError, RecursionError):
        # ValueError covers json.JSONDecodeError; RecursionError covers
        # pathologically nested input. Both mean "not parseable".
        return None

    if not isinstance(obj, dict):
        return None

    # "actions" takes priority; "action" is only used when "actions" is absent.
    key = "actions" if "actions" in obj else "action"
    payload = obj.get(key)
    if isinstance(payload, list):
        seq = payload
    elif isinstance(payload, dict):
        seq = [payload]
    else:
        return None

    # Reject the whole completion if any step is not an object.
    if not all(isinstance(step, dict) for step in seq):
        return None
    return list(seq)
def run_inference_via_client(
    client: InferenceClient,
    examples: List[Dict[str, Any]],
    input_mode: str = "json",
    max_new_tokens: int = 256,
    temperature: float = 0.2,
) -> List[Dict[str, Any]]:
    """Generate actions for every example through ``client.text_generation``.

    Args:
        client: Client whose ``text_generation`` method is called once per row.
        examples: Flattened game_history rows.
        input_mode: Forwarded to ``build_prompt`` ("json" or "nl").
        max_new_tokens: Forwarded to ``text_generation``.
        temperature: Forwarded to ``text_generation``.

    Returns:
        One dict per input row, in input order: a shallow copy of the row with
        ``raw_completion`` (the text returned by the client) and ``actions``
        (the parsed action list, or None when parsing failed) set.
    """

    def _annotate(row: Dict[str, Any]) -> Dict[str, Any]:
        generated = client.text_generation(
            build_prompt(row, input_mode=input_mode),
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        parsed = parse_actions_from_completion(generated)
        # Copy the row so the caller's input dicts are left untouched.
        return {**row, "raw_completion": generated, "actions": parsed}

    return [_annotate(row) for row in examples]
def run_inference_via_space(
    space_url: str,
    examples: List[Dict[str, Any]],
    max_new_tokens: int = 256,
    temperature: float = 0.2,
    timeout: int = 120,
    hf_token: Optional[str] = None,
) -> List[Dict[str, Any]]:
    """POST each example to the Space's ``/generate_actions`` route.

    Args:
        space_url: Base URL of the Space; trailing slashes are stripped.
        examples: Flattened game_history rows; ``phase``, ``turn`` and
            ``state`` are read from each row.
        max_new_tokens: Sent in the request body.
        temperature: Sent in the request body.
        timeout: Per-request timeout passed to ``requests.post``.
        hf_token: When truthy, sent as a ``Bearer`` Authorization header.

    Returns:
        One dict per input row, in input order: a shallow copy of the row with
        ``actions`` and ``raw_completion`` taken from the JSON response (None
        when the response lacks the key).

    Raises:
        Whatever ``requests.post`` or ``raise_for_status`` raises; the first
        failing request aborts the whole run.
    """
    root = space_url.rstrip("/")
    target = f"{root}/generate_actions"

    request_headers = {"Content-Type": "application/json"}
    if hf_token:
        request_headers["Authorization"] = f"Bearer {hf_token}"

    def _query(row: Dict[str, Any]) -> Dict[str, Any]:
        response = requests.post(
            target,
            json={
                "phase": row.get("phase"),
                "turn": row.get("turn"),
                "state": row.get("state", {}),
                "max_new_tokens": max_new_tokens,
                "temperature": temperature,
            },
            headers=request_headers,
            timeout=timeout,
        )
        response.raise_for_status()
        body = response.json()
        return {
            **row,
            "actions": body.get("actions"),
            "raw_completion": body.get("raw_completion"),
        }

    return [_query(row) for row in examples]
def load_examples(path: str) -> List[Dict[str, Any]]:
    """Read the input JSON file and return its top-level list of rows.

    Args:
        path: Location of a UTF-8 encoded JSON file.

    Returns:
        The decoded top-level list, unchanged.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        ValueError: If the file's top-level JSON value is not a list (also
            raised by the JSON decoder for malformed input).
    """
    source = Path(path)
    if not source.exists():
        raise FileNotFoundError(path)

    rows = json.loads(source.read_text(encoding="utf-8"))
    if isinstance(rows, list):
        return rows
    raise ValueError("Expected input JSON to be a list of examples (flat rows)")
def save_results(path: str, rows: List[Dict[str, Any]]) -> None:
    """Write ``rows`` to ``path`` as JSON Lines (one JSON object per line).

    Missing parent directories are created, an existing file is overwritten,
    and non-ASCII characters are written as-is in UTF-8.
    """
    destination = Path(path)
    destination.parent.mkdir(parents=True, exist_ok=True)
    with destination.open("w", encoding="utf-8") as sink:
        sink.writelines(
            json.dumps(record, ensure_ascii=False) + "\n" for record in rows
        )
def parse_args() -> argparse.Namespace:
    """Parse the command line for the cloud inference script.

    Returns:
        The parsed namespace. ``--input`` and ``--output`` are required.

    Exits via ``parser.error`` when none of ``--space-url``, ``--endpoint`` or
    ``--model-id`` is given, because no backend could be selected.
    """
    parser = argparse.ArgumentParser(
        description="Run cloud inference for Battlegrounds Qwen model via Hugging Face.",
    )

    # (flags, add_argument keyword arguments), registered in this order so the
    # generated --help output lists the options in the same sequence.
    option_table = [
        (
            ("--input",),
            {
                "required": True,
                "help": "Path to input JSON file (list of flattened game_history rows).",
            },
        ),
        (
            ("--output",),
            {
                "required": True,
                "help": "Path to output JSONL file with actions and raw completions.",
            },
        ),
        (
            ("--space-url",),
            {
                "default": None,
                "help": (
                    "URL of the Hugging Face Space hosting /generate_actions (e.g. "
                    "https://iteratehack-deepbattler.hf.space). If provided, the script calls "
                    "that endpoint instead of the Inference API."
                ),
            },
        ),
        (
            ("--model-id",),
            {
                "default": None,
                "help": (
                    "Hugging Face model repo id (e.g. iteratehack/deepbattler-battleground-gamehistory). "
                    "Used only if --space-url is omitted."
                ),
            },
        ),
        (
            ("--endpoint",),
            {
                "default": None,
                "help": (
                    "Full URL of a dedicated Inference Endpoint. If provided (and --space-url missing), "
                    "this takes precedence over --model-id."
                ),
            },
        ),
        (
            ("--hf-token",),
            {
                "default": None,
                "help": (
                    "Hugging Face access token. Needed for private Spaces/models. If omitted, use the token "
                    "from `huggingface-cli login` or HF_TOKEN env var."
                ),
            },
        ),
        (
            ("--input-mode",),
            {
                "choices": ["json", "nl"],
                "default": "json",
                "help": "Match the input_mode used during training (json or nl).",
            },
        ),
        (("--max-new-tokens",), {"type": int, "default": 256}),
        (("--temperature",), {"type": float, "default": 0.2}),
        (
            ("--request-timeout",),
            {
                "type": int,
                "default": 120,
                "help": "Timeout (seconds) for HTTP requests when using --space-url",
            },
        ),
        (
            ("--print-results",),
            {
                "action": "store_true",
                "help": "Print each output row (JSON) to stdout after inference.",
            },
        ),
    ]
    for flags, options in option_table:
        parser.add_argument(*flags, **options)

    parsed = parser.parse_args()
    # At least one backend selector must be supplied.
    if not (parsed.space_url or parsed.endpoint or parsed.model_id):
        parser.error("Provide --space-url, --endpoint, or --model-id")
    return parsed
def main() -> None:
    """Script entry point: load rows, run inference, write the JSONL output.

    Backend selection: ``--space-url`` wins when given; otherwise an
    ``InferenceClient`` is built from ``--endpoint`` if set, else from
    ``--model-id``. Note that ``--input-mode`` is only forwarded on the
    InferenceClient path; the Space path does not receive it.
    """
    args = parse_args()
    examples = load_examples(args.input)

    # Generation settings shared by both backends.
    generation = {
        "max_new_tokens": args.max_new_tokens,
        "temperature": args.temperature,
    }

    if args.space_url:
        results = run_inference_via_space(
            args.space_url,
            examples,
            timeout=args.request_timeout,
            hf_token=args.hf_token,
            **generation,
        )
    else:
        target = args.endpoint if args.endpoint else args.model_id
        client = InferenceClient(target, token=args.hf_token)
        results = run_inference_via_client(
            client,
            examples,
            input_mode=args.input_mode,
            **generation,
        )

    save_results(args.output, results)
    print(f"Wrote {len(results)} rows to {args.output}")

    if not args.print_results:
        return
    for row in results:
        print(json.dumps(row, ensure_ascii=False))


if __name__ == "__main__":
    main()