Commit e796c83 (parent: 8036a6e) · lbtwyk committed

Add HF Space GPU inference for Battlegrounds Qwen

Files changed:
- RL/infer_battleground_cloud.py  +277 -0
- app.py                          +206 -1
- requirements.txt                +2   -0
RL/infer_battleground_cloud.py  (ADDED, +277 lines)
#!/usr/bin/env python
# infer_battleground_cloud.py
#
# Cloud-based inference script for a fine-tuned Battlegrounds Qwen model hosted on Hugging Face.
#
# Usage examples:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --model-id iteratehack/deepbattler-battleground-gamehistory
#
# or, if you deploy a dedicated Inference Endpoint:
#   PYTHONPATH=. python RL/infer_battleground_cloud.py \
#       --input RL/datasets/game_history_2_flat.json \
#       --output RL/datasets/game_history_2_actions.jsonl \
#       --endpoint https://<your-endpoint>.inference.huggingface.cloud
#
# The script expects the same "state" structure and action JSON schema as
# train_battleground_rlaif_gamehistory.py.

import argparse
import json
from pathlib import Path
from typing import Any, Dict, List, Optional

from huggingface_hub import InferenceClient

from RL.battleground_nl_utils import game_state_to_natural_language


INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the game state JSON:
"""

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the description of the game state:
"""


def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    """Build a prompt from a flattened game_history-style example.

    This mirrors _build_prompt in train_battleground_rlaif_gamehistory.py so that
    the inference distribution matches training.

    The example should have:
    - phase: string (e.g., "PlayerTurn")
    - turn: int
    - state: nested dict with keys: game_state, player_hero, resources, board_state
    """

    state = example.get("state", {}) or {}

    if input_mode == "nl":
        nl_state = game_state_to_natural_language(state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        gs = state.get("game_state", {}) or {}
        phase = example.get("phase", gs.get("phase", "PlayerTurn"))
        turn = example.get("turn", gs.get("turn_number", 0))
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": phase,
            "turn": turn,
            "state": state,
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX

    return prefix + "\n" + state_text


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    """Parse a model completion into a list of atomic action dicts.

    Expected formats (same as training reward parser):
    - {"actions": [ {...}, {...}, ... ]}
    - {"action": [ {...}, {...}, ... ]}  # tolerated fallback
    """

    text = text.strip()
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    end_idx = text.rfind("}")
    if end_idx == -1:
        return None
    json_str = text[start_idx : end_idx + 1]

    try:
        obj = json.loads(json_str)
    except Exception:
        return None

    if not isinstance(obj, dict):
        return None

    seq = None
    if "actions" in obj:
        if isinstance(obj["actions"], list):
            seq = obj["actions"]
        elif isinstance(obj["actions"], dict):
            seq = [obj["actions"]]
    elif "action" in obj:
        if isinstance(obj["action"], list):
            seq = obj["action"]
        elif isinstance(obj["action"], dict):
            seq = [obj["action"]]

    if seq is None:
        return None

    actions: List[Dict[str, Any]] = []
    for step in seq:
        if not isinstance(step, dict):
            return None
        actions.append(step)
    return actions


def run_inference(
    client: InferenceClient,
    examples: List[Dict[str, Any]],
    input_mode: str = "json",
    max_new_tokens: int = 256,
    temperature: float = 0.2,
) -> List[Dict[str, Any]]:
    """Run inference over a list of examples and return enriched records.

    Each output row is the original example plus:
    - actions: parsed list of atomic action dicts (or None on parse failure)
    - raw_completion: raw text returned by the model
    """

    results: List[Dict[str, Any]] = []
    for ex in examples:
        prompt = build_prompt(ex, input_mode=input_mode)
        completion = client.text_generation(
            prompt,
            max_new_tokens=max_new_tokens,
            temperature=temperature,
        )
        actions = parse_actions_from_completion(completion)

        out_row = dict(ex)
        out_row["raw_completion"] = completion
        out_row["actions"] = actions
        results.append(out_row)

    return results


def load_examples(path: str) -> List[Dict[str, Any]]:
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(path)

    with p.open("r", encoding="utf-8") as f:
        data = json.load(f)

    if not isinstance(data, list):
        raise ValueError("Expected input JSON to be a list of examples (flat rows)")
    return data


def save_results(path: str, rows: List[Dict[str, Any]]) -> None:
    p = Path(path)
    p.parent.mkdir(parents=True, exist_ok=True)
    with p.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False) + "\n")


def parse_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run cloud inference for Battlegrounds Qwen model via Hugging Face.",
    )
    parser.add_argument(
        "--input",
        required=True,
        help="Path to input JSON file (list of flattened game_history rows).",
    )
    parser.add_argument(
        "--output",
        required=True,
        help="Path to output JSONL file with actions and raw completions.",
    )
    parser.add_argument(
        "--model-id",
        default=None,
        help=(
            "Hugging Face model repo id (e.g. iteratehack/deepbattler-battleground-gamehistory). "
            "If provided, serverless / hosted inference will be used."
        ),
    )
    parser.add_argument(
        "--endpoint",
        default=None,
        help=(
            "Full URL of a dedicated Inference Endpoint. If provided, this takes precedence "
            "over --model-id."
        ),
    )
    parser.add_argument(
        "--hf-token",
        default=None,
        help=(
            "Hugging Face access token. If omitted, the token from `huggingface-cli login` "
            "or HF_TOKEN env var will be used."
        ),
    )
    parser.add_argument(
        "--input-mode",
        choices=["json", "nl"],
        default="json",
        help="Match the input_mode used during training (json or nl).",
    )
    parser.add_argument("--max-new-tokens", type=int, default=256)
    parser.add_argument("--temperature", type=float, default=0.2)

    args = parser.parse_args()
    if not args.model_id and not args.endpoint:
        parser.error("You must provide either --model-id or --endpoint")

    return args


def main() -> None:
    args = parse_args()

    if args.endpoint:
        client = InferenceClient(args.endpoint, token=args.hf_token)
    else:
        client = InferenceClient(args.model_id, token=args.hf_token)

    examples = load_examples(args.input)
    results = run_inference(
        client,
        examples,
        input_mode=args.input_mode,
        max_new_tokens=args.max_new_tokens,
        temperature=args.temperature,
    )
    save_results(args.output, results)
    print(f"Wrote {len(results)} rows to {args.output}")


if __name__ == "__main__":
    main()
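Since parse_actions_from_completion is the piece most likely to bite on malformed model output, it can be sanity-checked offline without any Hugging Face credentials. A minimal sketch; the completion string below is invented for illustration:

# Hypothetical smoke test, not part of the commit.
from RL.infer_battleground_cloud import parse_actions_from_completion

# Extra chatter around the JSON is tolerated: the parser slices from the
# first "{" to the last "}" before calling json.loads.
sample = 'Sure! {"actions":[{"type":"ROLL","tavern_index":null,"hand_index":null,"board_index":null,"card_name":null}]}'
actions = parse_actions_from_completion(sample)
assert actions == [{"type": "ROLL", "tavern_index": None, "hand_index": None, "board_index": None, "card_name": None}]
print("parser OK:", actions)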
app.py  (CHANGED, +206 -1, @@ -1,7 +1,212 @@)

Before this commit, app.py was a seven-line FastAPI stub (the import, `app = FastAPI()`, and a root() route); the one removed line is the old `return {` body of root(). The resulting file:

from fastapi import FastAPI
from pydantic import BaseModel
from typing import Any, Dict, List, Optional

import json
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from peft import PeftModel

from RL.battleground_nl_utils import game_state_to_natural_language


BASE_MODEL_ID = "Qwen/Qwen3-4B-Instruct-2507"
ADAPTER_MODEL_ID = "iteratehack/battleground-rlaif-qwen-gamehistory-grpo"
DEFAULT_MAX_NEW_TOKENS = 256
DEFAULT_TEMPERATURE = 0.2


app = FastAPI()

tokenizer: Optional[AutoTokenizer] = None
model = None
device = "cuda" if torch.cuda.is_available() else "cpu"


INSTRUCTION_PREFIX = """You are a Hearthstone Battlegrounds AI.
Given the current game state as a JSON object, choose the best full-turn sequence
of actions and respond with a single JSON object in this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the game state JSON:
"""

INSTRUCTION_PREFIX_NL = """You are a Hearthstone Battlegrounds AI.
Given the following natural language description of the current game state, choose
the best full-turn sequence of actions and respond with a single JSON object in
this exact format:
{"actions":[{"type":"<ACTION_TYPE>","tavern_index":<int-or-null>,"hand_index":<int-or-null>,"board_index":<int-or-null>,"card_name":<string-or-null>}, ...]}
Rules:
1. Respond with JSON only. Do not add explanations or any extra text.
2. The top-level object must have exactly one key: "actions".
3. "actions" must be a JSON array (possibly empty, but usually 1+ steps) of
   atomic action objects.
4. Use 0-based integers for indices or null when not used.
5. "type" must be one of: "BUY_FROM_TAVERN","PLAY_FROM_HAND","SELL_FROM_BOARD",
   "HERO_POWER","ROLL","UPGRADE_TAVERN","FREEZE","END_TURN".
6. "card_name" must exactly match a card name from the game state when required,
   otherwise null.
Now here is the description of the game state:
"""


class GenerateRequest(BaseModel):
    phase: Optional[str] = None
    turn: Optional[int] = None
    state: Dict[str, Any]
    input_mode: str = "json"  # "json" or "nl"
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS
    temperature: float = DEFAULT_TEMPERATURE


def build_prompt(example: Dict[str, Any], input_mode: str = "json") -> str:
    state = example.get("state", {}) or {}

    if input_mode == "nl":
        nl_state = game_state_to_natural_language(state)
        prefix = INSTRUCTION_PREFIX_NL
        state_text = nl_state
    else:
        gs = state.get("game_state", {}) or {}
        phase = example.get("phase", gs.get("phase", "PlayerTurn"))
        turn = example.get("turn", gs.get("turn_number", 0))
        obj = {
            "task": "battlegrounds_policy_v1",
            "phase": phase,
            "turn": turn,
            "state": state,
        }
        state_text = json.dumps(obj, separators=(",", ":"), ensure_ascii=False)
        prefix = INSTRUCTION_PREFIX

    return prefix + "\n" + state_text


def parse_actions_from_completion(text: str) -> Optional[List[Dict[str, Any]]]:
    text = text.strip()
    start_idx = text.find("{")
    if start_idx == -1:
        return None
    end_idx = text.rfind("}")
    if end_idx == -1:
        return None
    json_str = text[start_idx : end_idx + 1]

    try:
        obj = json.loads(json_str)
    except Exception:
        return None

    if not isinstance(obj, dict):
        return None

    seq = None
    if "actions" in obj:
        if isinstance(obj["actions"], list):
            seq = obj["actions"]
        elif isinstance(obj["actions"], dict):
            seq = [obj["actions"]]
    elif "action" in obj:
        if isinstance(obj["action"], list):
            seq = obj["action"]
        elif isinstance(obj["action"], dict):
            seq = [obj["action"]]

    if seq is None:
        return None

    actions: List[Dict[str, Any]] = []
    for step in seq:
        if not isinstance(step, dict):
            return None
        actions.append(step)
    return actions


def load_model() -> None:
    global tokenizer, model
    if tokenizer is not None and model is not None:
        return

    tok = AutoTokenizer.from_pretrained(ADAPTER_MODEL_ID, trust_remote_code=True)
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    tok.padding_side = "left"

    if torch.cuda.is_available():
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            device_map="auto",
            torch_dtype=torch.bfloat16,
            trust_remote_code=True,
        )
    else:
        base = AutoModelForCausalLM.from_pretrained(
            BASE_MODEL_ID,
            torch_dtype=torch.float32,
            trust_remote_code=True,
        )

    peft_model = PeftModel.from_pretrained(base, ADAPTER_MODEL_ID)
    if not torch.cuda.is_available():
        peft_model.to(device)
    peft_model.eval()

    tokenizer = tok
    model = peft_model


@app.on_event("startup")
async def _startup_event() -> None:
    load_model()


@app.get("/")
def root():
    return {
        "status": "ok",
        "message": "DeepBattler Battlegrounds Space is running",
        "base_model": BASE_MODEL_ID,
        "adapter_model": ADAPTER_MODEL_ID,
    }


@app.post("/generate_actions")
def generate_actions(req: GenerateRequest):
    load_model()

    example = {
        "phase": req.phase,
        "turn": req.turn,
        "state": req.state,
    }
    prompt = build_prompt(example, input_mode=req.input_mode)

    inputs = tokenizer(prompt, return_tensors="pt")
    inputs = {k: v.to(device) for k, v in inputs.items()}

    with torch.no_grad():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=req.max_new_tokens,
            do_sample=True,
            temperature=req.temperature,
        )

    generated_ids = output_ids[0, inputs["input_ids"].shape[1] :]
    completion = tokenizer.decode(generated_ids, skip_special_tokens=True)
    actions = parse_actions_from_completion(completion)

    return {
        "actions": actions,
        "raw_completion": completion,
    }
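Once the Space is up, a client call against the new endpoint might look like this. A sketch only: the Space URL is a placeholder, and the `state` stub is illustrative, since a real payload carries the full nested game_state/player_hero/resources/board_state structure:

import requests

SPACE_URL = "https://<your-space>.hf.space"  # placeholder: substitute your Space host

payload = {
    "phase": "PlayerTurn",
    "turn": 3,
    "state": {"game_state": {"phase": "PlayerTurn", "turn_number": 3}},  # illustrative stub
    "input_mode": "json",   # should match the input_mode used during training
    "max_new_tokens": 256,
    "temperature": 0.2,
}

resp = requests.post(f"{SPACE_URL}/generate_actions", json=payload, timeout=120)
resp.raise_for_status()
body = resp.json()
print(body["actions"])         # parsed action list, or null on parse failure
print(body["raw_completion"])  # raw model text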
requirements.txt  (CHANGED, +2 lines)

@@ -8,3 +8,5 @@ modelscope>=1.15.0
 datasets>=2.19.0
 peft>=0.11.1
 trl>=0.25.1
+fastapi>=0.115.0
+uvicorn[standard]>=0.30.0
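For local testing outside the Space, the app can presumably be launched with uvicorn directly. Port 7860 follows the usual Hugging Face Spaces convention but is an assumption here, since the commit does not include a launch command:

# Hypothetical local launcher, not part of the commit.
import uvicorn

if __name__ == "__main__":
    # "app:app" = module app.py, FastAPI instance `app`; port 7860 is an assumption
    uvicorn.run("app:app", host="0.0.0.0", port=7860)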