LovecaSim / engine /tests /ability_test_helper.py
trioskosmos's picture
Upload folder using huggingface_hub
bb3fbf9 verified
import json
import os
import sys
from typing import List, Optional
# Ensure project root is in path, so project packages (`engine`, `backend`)
# resolve regardless of the directory the tests are launched from.
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), "../../"))
if PROJECT_ROOT not in sys.path:
    sys.path.insert(0, PROJECT_ROOT)

# The compiled Rust extension is first tried as a top-level module, then as
# `backend.engine_rust` (common in some envs).
try:
    import engine_rust
except ImportError:
    try:
        from backend import engine_rust
    except ImportError as err:
        # Chain the cause so the real import failure (missing .so/.pyd, ABI
        # mismatch, ...) stays visible in the traceback.
        raise ImportError("Could not import engine_rust. Make sure the Rust extension is built.") from err
from engine.game.enums import Phase
class AbilityTestContext:
    """
    Helper for writing Rust engine tests.

    Provides a high-level API for state setup, action execution, and
    verification on top of ``engine_rust.PyGameState``.

    Card UIDs are built by ``mk_uid`` as ``base_id | (instance_idx << 20)``.
    The setup helpers draw instance indices from fixed ranges: deck 0+ (P0) /
    1000+ (P1), energy 0-9 (P0) / 100-109 (P1), lives 200-202 (P0) /
    300-302 (P1), hand overrides 500+, energy overrides 600+.
    """

    def __init__(self, compiled_data_path: str = "data/cards_compiled.json"):
        """
        Load the compiled card database and create a fresh game state.

        Args:
            compiled_data_path: Path to the compiled card JSON. If it does not
                exist as given, it is retried relative to ``PROJECT_ROOT``.

        Raises:
            FileNotFoundError: If the file exists at neither location.
        """
        if not os.path.exists(compiled_data_path):
            # Try alternative path for test execution environments
            alt_path = os.path.join(PROJECT_ROOT, compiled_data_path)
            if os.path.exists(alt_path):
                compiled_data_path = alt_path
            else:
                raise FileNotFoundError(f"Compiled data not found: {compiled_data_path}")
        with open(compiled_data_path, "r", encoding="utf-8") as f:
            self.json_data = f.read()
        # Parsed copy for Python-side lookups (find_card_id); the Rust card
        # database is built from the same raw JSON text.
        self.db_raw = json.loads(self.json_data)
        self.db = engine_rust.PyCardDatabase(self.json_data)
        self.gs = engine_rust.PyGameState(self.db)
        # Low 20 bits of a UID hold the base ID (matches the shift in mk_uid).
        self.BASE_ID_MASK = 0xFFFFF

    def mk_uid(self, base_id: int, instance_idx: int) -> int:
        """
        Create a Unique ID from base ID and instance index.

        The instance index is shifted left by 20 bits and OR-ed with
        ``base_id``, so ``base_id`` should fit in ``BASE_ID_MASK``.
        """
        return base_id | (instance_idx << 20)

    def find_card_id(self, card_no: str, db_type: Optional[str] = None) -> int:
        """
        Find the internal ID for a card number.

        Args:
            card_no: Card number matched against each entry's ``"card_no"``.
            db_type: Single top-level section of the compiled JSON to search.
                When omitted, ``member_db``, ``live_db`` and ``energy_db`` are
                searched in that order.

        Returns:
            The key of the first matching entry, converted to ``int``.

        Raises:
            ValueError: If no searched section contains ``card_no``.
        """
        dbs = [db_type] if db_type else ["member_db", "live_db", "energy_db"]
        for db_name in dbs:
            for cid, card in self.db_raw.get(db_name, {}).items():
                if card.get("card_no") == card_no:
                    return int(cid)
        raise ValueError(f"Card {card_no} not found in {dbs}")

    def setup_game(self, p0_deck_nos: Optional[List[str]] = None, p1_deck_nos: Optional[List[str]] = None):
        """
        Initialize the game with specific decks (card numbers).

        Args:
            p0_deck_nos: Player 0 main deck as card numbers. Empty or omitted
                gives a 40-card dummy deck of base ID 1.
            p1_deck_nos: Same for player 1.

        Both players also receive 10 energy cards (base ID 40001) and 3 life
        cards (base ID 1).
        """

        def nos_to_uids(nos, offset=0):
            # Copies of the same card get increasing instance indices so each
            # UID within a deck is distinct.
            if not nos:
                return [self.mk_uid(1, i + offset) for i in range(40)]  # Default dummy deck
            uids = []
            counts = {}
            for no in nos:
                base = self.find_card_id(no)
                count = counts.get(base, 0)
                uids.append(self.mk_uid(base, count + offset))
                counts[base] = count + 1
            return uids

        p0_main = nos_to_uids(p0_deck_nos, 0)
        p1_main = nos_to_uids(p1_deck_nos, 1000)  # Offset instance IDs for P1
        # Default Energy and Lives
        p0_energy = [self.mk_uid(40001, i) for i in range(10)]
        p1_energy = [self.mk_uid(40001, 100 + i) for i in range(10)]
        p0_lives = [self.mk_uid(1, 200 + i) for i in range(3)]
        p1_lives = [self.mk_uid(1, 300 + i) for i in range(3)]
        self.gs.initialize_game(p0_main, p1_main, p0_energy, p1_energy, p0_lives, p1_lives)

    def skip_mulligan(self):
        """Skip mulligan phases for both players by submitting action 0 in each."""
        # int() mirrors the phase comparison in reach_main_phase.
        if int(self.gs.phase) == -1:  # MULLIGAN_P1
            self.gs.step(0)
        if int(self.gs.phase) == 0:  # MULLIGAN_P2
            self.gs.step(0)

    def reach_main_phase(self, max_steps: int = 20):
        """
        Advance through Active, Energy, Draw phases to reach Main Phase.

        Skips mulligan, then submits action 0 until the phase is at least
        ``Phase.MAIN`` or ``max_steps`` steps were taken. Hitting the limit
        does not raise; use ``assert_phase`` to verify the result.
        """
        self.skip_mulligan()
        main = int(Phase.MAIN)
        for _ in range(max_steps):
            if int(self.gs.phase) >= main:
                break
            self.gs.step(0)  # Pass/End Phase

    def set_hand(self, player_idx: int, card_nos: List[str]):
        """
        Directly set a player's hand.

        Card ``i`` gets instance index ``500 + i``.
        """
        # NOTE(review): unlike setup_game, instance indices are not offset per
        # player, so both players' hands can contain identical UIDs. Confirm
        # the engine tolerates this.
        uids = [self.mk_uid(self.find_card_id(no), 500 + i) for i, no in enumerate(card_nos)]  # Instance ID 500+
        self.gs.set_hand_cards(player_idx, uids)

    def set_energy(self, player_idx: int, count: int, tapped_count: int = 0):
        """
        Directly set a player's energy zone.

        Args:
            player_idx: Player whose energy zone is replaced.
            count: Number of energy cards (base ID 40001) to place.
            tapped_count: How many of those cards, from the front, are tapped.

        Raises:
            ValueError: If ``tapped_count`` is not within ``0..count``.
                Otherwise the tapped flags would not line up one-to-one
                with the energy cards.
        """
        if not 0 <= tapped_count <= count:
            raise ValueError(f"tapped_count must be between 0 and count ({count}), got {tapped_count}")
        p = self.gs.get_player(player_idx)
        # NOTE(review): instance IDs 600+ are shared by both players.
        p.energy_zone = [self.mk_uid(40001, 600 + i) for i in range(count)]  # Instance ID 600+
        p.tapped_energy = [True] * tapped_count + [False] * (count - tapped_count)
        self.gs.set_player(player_idx, p)

    def play_member(self, hand_idx: int, slot_idx: int):
        """
        Play a member from hand to a specific slot.

        The action is encoded as ``1 + hand_idx * 3 + slot_idx``.

        Raises:
            ValueError: If ``hand_idx`` is negative or ``slot_idx`` is not 0, 1
                or 2. Such values would silently encode the action for a
                different hand card.
        """
        if hand_idx < 0:
            raise ValueError(f"hand_idx must be non-negative, got {hand_idx}")
        if slot_idx not in (0, 1, 2):
            raise ValueError(f"slot_idx must be 0, 1 or 2, got {slot_idx}")
        self.gs.step(1 + hand_idx * 3 + slot_idx)

    def get_legal_actions(self) -> List[int]:
        """Get the list of legal action IDs."""
        return list(self.gs.get_legal_action_ids())

    def assert_phase(self, expected_phase: "Phase"):
        """Assert the current phase (compared as ints)."""
        assert int(self.gs.phase) == int(expected_phase), f"Expected phase {expected_phase}, got {self.gs.phase}"

    def assert_legal_action(self, action_id: int):
        """Assert that an action is currently legal."""
        legal = self.get_legal_actions()
        assert action_id in legal, f"Action {action_id} is not legal. Legal: {legal}"

    def log(self, msg: str):
        """Helper for logging if needed (prints with a ``[TEST]`` prefix)."""
        print(f"[TEST] {msg}")

    def print_rule_log(self, limit: int = 10):
        """Print the last ``limit`` entries from the engine's rule log."""
        log = self.gs.rule_log
        start = max(0, len(log) - limit)
        for i in range(start, len(log)):
            print(f" {log[i]}")