// Source file: rabukasim/engine_rust_src/src/py_bindings.rs
// Uploader: trioskosmos
// Commit message: "Upload folder using huggingface_hub"
// Revision: 463f868 (verified)
use crate::core::alphazero_encoding::AlphaZeroEncoding;
use crate::core::alphazero_evaluator::AlphaZeroEvaluator;
use crate::core::heuristics::{EvalMode, HeuristicConfig, LegacyHeuristic, OriginalHeuristic};
use crate::core::logic::constants::STAGE_SLOT_COUNT;
use crate::core::logic::{ChoiceType, GameState, Phase, PlayerState, StandardizedState};
use crate::core::mcts::{SearchHorizon, MCTS};
use numpy::{PyArray1, PyArray2, PyArrayMethods, PyReadonlyArray1};
use pyo3::prelude::*;
use rayon::prelude::*;
use rand::prelude::*;
use rand::rngs::SmallRng;
use smallvec::SmallVec;
use std::sync::Arc;
use serde_json::json;
// use crate::core::heuristics::{OriginalHeuristic, SimpleHeuristic};
/// Python-visible wrapper that owns a copy of one player's `PlayerState`.
///
/// The wrapper is `Clone`; the game-state bindings in this file hand out clones
/// of the engine's player state rather than references into it.
#[pyclass]
#[derive(Clone)]
pub struct PyPlayerState {
/// The wrapped engine-side player state.
pub inner: PlayerState,
}
// Python accessors for a player's zones, counters and flags.
//
// Most fields are exposed twice: as a property (`#[getter]` / `#[setter]`) and as an
// explicit `set_*` method; the property setter always delegates to the method.
// Card ids live in the engine as `i32` and are surfaced to Python as `u32` via plain
// `as` casts, so a negative engine id shows up as a large unsigned value.
#[pymethods]
impl PyPlayerState {
    /// Id of this player as stored in the wrapped state.
    #[getter]
    fn player_id(&self) -> u8 {
        self.inner.player_id
    }

    // ---- score -------------------------------------------------------------
    #[getter]
    fn score(&self) -> u32 {
        self.inner.score
    }
    #[setter(score)]
    fn set_score_prop(&mut self, val: u32) {
        self.set_score(val);
    }
    fn set_score(&mut self, val: u32) {
        self.inner.score = val;
    }

    // ---- success lives -----------------------------------------------------
    #[getter]
    fn success_lives(&self) -> Vec<u32> {
        self.inner.success_lives.iter().map(|&id| id as u32).collect()
    }
    #[setter(success_lives)]
    fn set_success_lives_prop(&mut self, val: Vec<u32>) {
        self.set_success_lives(val);
    }
    fn set_success_lives(&mut self, val: Vec<u32>) {
        self.inner.success_lives = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- hand --------------------------------------------------------------
    #[getter]
    fn hand(&self) -> Vec<u32> {
        self.inner.hand.iter().map(|&id| id as u32).collect()
    }
    #[setter(hand)]
    fn set_hand_prop(&mut self, val: Vec<u32>) {
        self.set_hand(val);
    }
    fn set_hand(&mut self, val: Vec<u32>) {
        self.inner.hand = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- stage (three slots) -----------------------------------------------
    #[getter]
    fn stage(&self) -> [i32; 3] {
        // The field already is a `[i32; 3]`, so it can be copied out directly.
        self.inner.stage
    }
    #[setter(stage)]
    fn set_stage_prop(&mut self, val: [i32; 3]) {
        self.set_stage(val);
    }
    fn set_stage(&mut self, val: [i32; 3]) {
        self.inner.stage = val;
    }

    // ---- discard -----------------------------------------------------------
    #[getter]
    fn discard(&self) -> Vec<u32> {
        self.inner.discard.iter().map(|&id| id as u32).collect()
    }
    #[setter(discard)]
    fn set_discard_prop(&mut self, val: Vec<u32>) {
        self.set_discard(val);
    }
    fn set_discard(&mut self, val: Vec<u32>) {
        self.inner.discard = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- exile (read-only) -------------------------------------------------
    #[getter]
    fn exile(&self) -> Vec<u32> {
        self.inner.exile.iter().map(|&id| id as u32).collect()
    }

    // ---- deck --------------------------------------------------------------
    #[getter]
    fn deck(&self) -> Vec<u32> {
        self.inner.deck.iter().map(|&id| id as u32).collect()
    }
    #[setter(deck)]
    fn set_deck_prop(&mut self, val: Vec<u32>) {
        self.set_deck(val);
    }
    fn set_deck(&mut self, val: Vec<u32>) {
        self.inner.deck = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- initial deck ------------------------------------------------------
    #[getter]
    fn initial_deck(&self) -> Vec<u32> {
        self.inner.initial_deck.iter().map(|&id| id as u32).collect()
    }
    #[setter(initial_deck)]
    fn set_initial_deck_prop(&mut self, val: Vec<u32>) {
        self.set_initial_deck(val);
    }
    fn set_initial_deck(&mut self, val: Vec<u32>) {
        self.inner.initial_deck = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- energy zone -------------------------------------------------------
    #[getter]
    fn energy_zone(&self) -> Vec<u32> {
        self.inner.energy_zone.iter().map(|&id| id as u32).collect()
    }
    #[setter(energy_zone)]
    fn set_energy_zone_prop(&mut self, val: Vec<u32>) {
        self.set_energy_zone(val);
    }
    fn set_energy_zone(&mut self, val: Vec<u32>) {
        self.inner.energy_zone = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- energy deck -------------------------------------------------------
    #[getter]
    fn energy_deck(&self) -> Vec<u32> {
        self.inner.energy_deck.iter().map(|&id| id as u32).collect()
    }
    #[setter(energy_deck)]
    fn set_energy_deck_prop(&mut self, val: Vec<u32>) {
        self.set_energy_deck(val);
    }
    fn set_energy_deck(&mut self, val: Vec<u32>) {
        self.inner.energy_deck = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- live zone (three slots) -------------------------------------------
    #[getter]
    fn live_zone(&self) -> [i32; 3] {
        self.inner.live_zone
    }
    #[setter(live_zone)]
    fn set_live_zone_prop(&mut self, val: [i32; 3]) {
        self.set_live_zone(val);
    }
    fn set_live_zone(&mut self, val: [i32; 3]) {
        self.inner.live_zone = val;
    }

    /// Revealed flag of each of the three live-zone slots.
    #[getter]
    fn live_zone_revealed(&self) -> [bool; 3] {
        [0, 1, 2].map(|slot| self.inner.is_revealed(slot))
    }
    #[setter(live_zone_revealed)]
    fn set_live_zone_revealed_prop(&mut self, val: [bool; 3]) {
        self.set_live_zone_revealed(val);
    }
    fn set_live_zone_revealed(&mut self, val: [bool; 3]) {
        for (slot, &revealed) in val.iter().enumerate() {
            self.inner.set_revealed(slot, revealed);
        }
    }

    // ---- tapped energy -----------------------------------------------------
    /// One flag per card currently in the energy zone.
    #[getter]
    fn tapped_energy(&self) -> Vec<bool> {
        (0..self.inner.energy_zone.len())
            .map(|idx| self.inner.is_energy_tapped(idx))
            .collect()
    }
    #[setter(tapped_energy)]
    fn set_tapped_energy_prop(&mut self, val: Vec<bool>) {
        self.set_tapped_energy(val);
    }
    /// Clears the whole tapped mask, then taps every index whose flag is `true`.
    fn set_tapped_energy(&mut self, val: Vec<bool>) {
        self.inner.tapped_energy_mask = 0;
        for (idx, _) in val.iter().enumerate().filter(|(_, &tapped)| tapped) {
            self.inner.set_energy_tapped(idx, true);
        }
    }

    // ---- tapped / moved members --------------------------------------------
    #[getter]
    fn tapped_members(&self) -> [bool; 3] {
        [0, 1, 2].map(|slot| self.inner.is_tapped(slot))
    }
    #[setter(tapped_members)]
    fn set_tapped_members_prop(&mut self, val: [bool; 3]) {
        self.set_tapped_members(val);
    }
    fn set_tapped_members(&mut self, val: [bool; 3]) {
        for (slot, &tapped) in val.iter().enumerate() {
            self.inner.set_tapped(slot, tapped);
        }
    }
    #[setter(moved_members_this_turn)]
    fn set_moved_members_this_turn_prop(&mut self, val: [bool; 3]) {
        self.set_moved_members_this_turn(val);
    }
    fn set_moved_members_this_turn(&mut self, val: [bool; 3]) {
        for (slot, &moved) in val.iter().enumerate() {
            self.inner.set_moved(slot, moved);
        }
    }

    // ---- looked / revealed cards -------------------------------------------
    // `base_revealed_cards`, `revealed_cards` and `looked_cards` are three names for
    // the same underlying `looked_cards` buffer.
    #[getter]
    fn base_revealed_cards(&self) -> Vec<u32> {
        self.inner.looked_cards.iter().map(|&id| id as u32).collect()
    }
    #[setter(base_revealed_cards)]
    fn set_base_revealed_cards_prop(&mut self, val: Vec<u32>) {
        self.set_base_revealed_cards(val);
    }
    fn set_base_revealed_cards(&mut self, val: Vec<u32>) {
        self.inner.looked_cards = val.into_iter().map(|id| id as i32).collect();
    }
    #[getter]
    fn revealed_cards(&self) -> Vec<u32> {
        self.inner.looked_cards.iter().map(|&id| id as u32).collect()
    }
    #[setter(revealed_cards)]
    fn set_revealed_cards_prop(&mut self, val: Vec<u32>) {
        self.set_revealed_cards(val);
    }
    fn set_revealed_cards(&mut self, val: Vec<u32>) {
        self.inner.looked_cards = val.into_iter().map(|id| id as i32).collect();
    }
    #[getter]
    fn looked_cards(&self) -> Vec<u32> {
        self.inner.looked_cards.iter().map(|&id| id as u32).collect()
    }
    fn set_looked_cards(&mut self, val: Vec<u32>) {
        self.inner.looked_cards = val.into_iter().map(|id| id as i32).collect();
    }

    // ---- counts ------------------------------------------------------------
    #[getter]
    fn deck_count(&self) -> usize {
        self.inner.deck.len()
    }
    #[getter]
    fn energy_deck_count(&self) -> usize {
        self.inner.energy_deck.len()
    }

    // ---- mulligan selection ------------------------------------------------
    #[getter]
    fn mulligan_selection(&self) -> u64 {
        self.inner.mulligan_selection
    }
    #[setter(mulligan_selection)]
    fn set_mulligan_selection_prop(&mut self, val: u64) {
        self.set_mulligan_selection(val);
    }
    fn set_mulligan_selection(&mut self, val: u64) {
        self.inner.mulligan_selection = val;
    }

    // ---- baton touch -------------------------------------------------------
    #[getter]
    fn baton_touch_count(&self) -> u32 {
        u32::from(self.inner.baton_touch_count)
    }
    #[setter(baton_touch_count)]
    fn set_baton_touch_count_prop(&mut self, val: u32) {
        self.set_baton_touch_count(val);
    }
    /// Stores the count, saturating at `u8::MAX` instead of wrapping (256 used to become 0).
    fn set_baton_touch_count(&mut self, val: u32) {
        self.inner.baton_touch_count = val.min(u32::from(u8::MAX)) as u8;
    }
    #[getter]
    fn baton_touch_limit(&self) -> u32 {
        u32::from(self.inner.baton_touch_limit)
    }
    #[setter(baton_touch_limit)]
    fn set_baton_touch_limit_prop(&mut self, val: u32) {
        self.set_baton_touch_limit(val);
    }
    /// Stores the limit, saturating at `u8::MAX` instead of wrapping.
    fn set_baton_touch_limit(&mut self, val: u32) {
        self.inner.baton_touch_limit = val.min(u32::from(u8::MAX)) as u8;
    }

    // ---- hand_added_turn ---------------------------------------------------
    #[getter]
    fn hand_added_turn(&self) -> Vec<u32> {
        self.inner
            .hand_added_turn
            .iter()
            .map(|&turn| turn as u32)
            .collect()
    }
    #[setter(hand_added_turn)]
    fn set_hand_added_turn_prop(&mut self, val: Vec<u32>) {
        self.set_hand_added_turn(val);
    }
    fn set_hand_added_turn(&mut self, val: Vec<u32>) {
        self.inner.hand_added_turn = val.into_iter().map(|turn| turn as i32).collect();
    }

    // ---- yell cards (kept as inert stubs) ----------------------------------
    // The getter always returns an empty list and both setters ignore their input.
    #[getter]
    fn yell_cards(&self) -> Vec<u32> {
        Vec::new()
    }
    #[setter(yell_cards)]
    fn set_yell_cards_prop(&mut self, _val: Vec<u32>) {}
    fn set_yell_cards(&mut self, _val: Vec<u32>) {}

    // ---- heart buffs -------------------------------------------------------
    #[getter]
    pub fn heart_buffs(&self) -> Vec<Vec<i32>> {
        self.inner
            .heart_buffs
            .iter()
            .map(|hearts| hearts.to_array().iter().map(|&n| n as i32).collect())
            .collect()
    }
    #[setter(heart_buffs)]
    pub fn set_heart_buffs_prop(&mut self, val: Vec<Vec<i32>>) {
        self.set_heart_buffs(val);
    }
    /// Overwrites the heart buffs of the first three slots.
    ///
    /// Rows that do not have exactly 7 entries are skipped; each count is clamped
    /// into `0..=255` before being stored.
    pub fn set_heart_buffs(&mut self, val: Vec<Vec<i32>>) {
        for (slot, row) in val.iter().enumerate().take(3) {
            if row.len() != 7 {
                continue;
            }
            for (color, &count) in row.iter().enumerate() {
                self.inner.heart_buffs[slot].set_color_count(color, count.clamp(0, 255) as u8);
            }
        }
    }

    // ---- blade buffs -------------------------------------------------------
    #[getter]
    pub fn blade_buffs(&self) -> Vec<i32> {
        self.inner.blade_buffs.iter().map(|&b| b as i32).collect()
    }
    #[setter(blade_buffs)]
    pub fn set_blade_buffs_prop(&mut self, val: Vec<i32>) {
        self.set_blade_buffs(val);
    }
    /// Overwrites the blade buffs of the first three slots; extra entries are ignored.
    ///
    /// Values are clamped to the `i16` range instead of being wrapped by the cast.
    pub fn set_blade_buffs(&mut self, val: Vec<i32>) {
        for (slot, &buff) in val.iter().enumerate().take(3) {
            self.inner.blade_buffs[slot] =
                buff.clamp(i32::from(i16::MIN), i32::from(i16::MAX)) as i16;
        }
    }

    // ---- raw masks / flags -------------------------------------------------
    #[getter]
    pub fn activated_energy_group_mask(&self) -> u32 {
        self.inner.activated_energy_group_mask
    }
    #[setter(activated_energy_group_mask)]
    pub fn set_activated_energy_group_mask(&mut self, val: u32) {
        self.inner.activated_energy_group_mask = val;
    }
    #[getter]
    pub fn activated_member_group_mask(&self) -> u32 {
        self.inner.activated_member_group_mask
    }
    #[setter(activated_member_group_mask)]
    pub fn set_activated_member_group_mask(&mut self, val: u32) {
        self.inner.activated_member_group_mask = val;
    }
    #[getter]
    pub fn flags(&self) -> u32 {
        self.inner.flags
    }
    #[setter(flags)]
    pub fn set_flags(&mut self, val: u32) {
        self.inner.flags = val;
    }
}
/// Read-only snapshot of a pending interaction, exposed to Python.
///
/// All three fields are plain values readable from Python via `#[pyo3(get)]`.
#[pyclass]
#[derive(Clone)]
pub struct PyPendingInteraction {
/// Name of the choice type as a string.
#[pyo3(get)]
pub choice_type: String,
/// Filter attribute value copied from the interaction.
#[pyo3(get)]
pub filter_attr: u64,
/// Debug-formatted (`{:?}`) text of the interaction's ability context.
#[pyo3(get)]
pub ctx: String, // Stringified AbilityContext
}
/// Python-visible handle to the card database.
///
/// The database is held behind an `Arc`, so cloning this wrapper only bumps the
/// reference count and shares the same database.
#[pyclass]
#[derive(Clone)]
pub struct PyCardDatabase {
/// Shared, reference-counted card database.
pub inner: std::sync::Arc<crate::core::logic::CardDatabase>,
}
#[pymethods]
impl PyCardDatabase {
#[new]
fn new(json_str: &str) -> PyResult<Self> {
let db = crate::core::logic::CardDatabase::from_json(json_str).map_err(
|e: serde_json::Error| pyo3::exceptions::PyValueError::new_err(e.to_string()),
)?;
Ok(Self {
inner: std::sync::Arc::new(db),
})
}
#[getter]
fn member_count(&self) -> usize {
self.inner.members.len()
}
#[getter]
fn live_count(&self) -> usize {
self.inner.lives.len()
}
fn has_member(&self, card_id: u32) -> bool {
self.inner.members.contains_key(&(card_id as i32))
}
fn get_member_ids(&self) -> Vec<u32> {
self.inner.members.keys().map(|&k| k as u32).collect()
}
fn get_live_ids(&self) -> Vec<u32> {
self.inner.lives.keys().map(|&k| k as u32).collect()
}
fn get_energy_ids(&self) -> Vec<u32> {
self.inner.energy_db.keys().map(|&k| k as u32).collect()
}
#[getter]
fn is_vanilla(&self) -> bool {
self.inner.is_vanilla
}
#[setter]
fn set_is_vanilla(&mut self, val: bool) {
if let Some(db) = std::sync::Arc::get_mut(&mut self.inner) {
db.is_vanilla = val;
} else {
let mut new_inner = (*self.inner).clone();
new_inner.is_vanilla = val;
self.inner = std::sync::Arc::new(new_inner);
}
}
}
/// Python-visible wrapper around the engine's `GameState`.
#[pyclass]
pub struct PyGameState {
/// The wrapped engine game state.
pub inner: GameState,
/// Card database handle used for every rules query made through this wrapper.
pub db: PyCardDatabase,
/// Scratch buffer filled by `get_legal_actions`; sized to `ACTION_SPACE` on construction.
pub legal_action_buffer: Vec<bool>,
}
#[pymethods]
impl PyGameState {
#[new]
fn new(db: PyCardDatabase) -> PyResult<Self> {
Ok(Self {
inner: GameState::default(),
db,
legal_action_buffer: vec![false; crate::core::logic::ACTION_SPACE],
})
}
pub fn copy(&self) -> Self {
Self {
inner: self.inner.clone(),
db: self.db.clone(),
legal_action_buffer: vec![false; crate::core::logic::ACTION_SPACE],
}
}
pub fn ping(&self) -> String {
"pong_v_force_fix_1215".to_string()
}
#[getter]
fn db(&self) -> PyCardDatabase {
self.db.clone()
}
#[getter]
fn current_player(&self) -> u8 {
self.inner.current_player
}
#[setter]
fn set_current_player(&mut self, val: u8) {
self.inner.current_player = val;
}
#[getter]
fn first_player(&self) -> u8 {
self.inner.first_player
}
#[setter]
fn set_first_player(&mut self, val: u8) {
self.inner.first_player = val;
}
#[getter]
fn rps_choices(&self) -> [i8; 2] {
self.inner.rps_choices
}
#[getter]
fn rule_log(&self) -> Vec<String> {
self.inner.ui.rule_log.clone().unwrap_or_default()
}
#[getter]
fn trace_log(&self) -> Vec<String> {
self.inner.debug.trace_log.clone()
}
#[getter]
fn bytecode_log(&self) -> Vec<String> {
self.inner.ui.bytecode_log.clone()
}
fn clear_bytecode_log(&mut self) {
self.inner.ui.bytecode_log.clear();
}
#[getter]
fn silent(&self) -> bool {
self.inner.ui.silent
}
#[setter]
fn set_silent(&mut self, val: bool) {
self.inner.ui.silent = val;
}
#[getter]
fn turn_history(&self) -> Vec<String> {
if let Some(ref history) = self.inner.core.turn_history {
history.iter().map(|e| format!("{:?}", e)).collect()
} else {
Vec::new()
}
}
fn generate_execution_id(&mut self) -> u32 {
self.inner.generate_execution_id()
}
fn clear_execution_id(&mut self) {
self.inner.clear_execution_id();
}
fn get_current_execution_id(&self) -> Option<u32> {
self.inner.ui.current_execution_id
}
fn log(&mut self, msg: String) {
self.inner.log(msg);
}
#[getter]
fn phase(&self) -> i8 {
self.inner.phase as i8
}
#[getter]
fn phase_name(&self) -> String {
format!("{:?}", self.inner.phase)
}
#[getter]
fn acting_player(&self) -> u8 {
match self.inner.phase {
Phase::Response => {
if let Some(pi) = self.inner.interaction_stack.last() {
pi.ctx.player_id as u8
} else {
self.inner.current_player
}
}
_ => self.inner.current_player,
}
}
#[getter]
fn turn(&self) -> u32 {
self.inner.turn as u32
}
#[setter]
fn set_turn(&mut self, val: u32) {
self.inner.turn = val as u16;
}
#[getter]
fn debug_mode(&self) -> bool {
self.inner.debug.debug_mode
}
#[setter(debug_mode)]
fn set_debug_mode(&mut self, val: bool) {
self.inner.debug.debug_mode = val;
}
#[getter]
fn debug_ignore_conditions(&self) -> bool {
self.inner.debug.debug_ignore_conditions
}
#[setter(debug_ignore_conditions)]
fn set_debug_ignore_conditions(&mut self, val: bool) {
self.inner.debug.debug_ignore_conditions = val;
}
fn apply_state_json(&mut self, json_str: &str) -> PyResult<()> {
let new_state: GameState = serde_json::from_str(json_str).map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Invalid state JSON: {}", e))
})?;
self.inner = new_state;
Ok(())
}
pub fn to_json(&self) -> PyResult<String> {
serde_json::to_string(&self.inner).map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Serialization error: {}", e))
})
}
#[pyo3(signature = (room_id, mode, include_tensor=true, history=None))]
fn to_standardized_json(
&self,
room_id: String,
mode: String,
include_tensor: bool,
history: Option<Vec<PyRef<PyGameState>>>,
) -> PyResult<String> {
let mut room_info = std::collections::HashMap::new();
room_info.insert("id".to_string(), room_id);
room_info.insert("mode".to_string(), mode);
let rs_history = history.map(|h| h.into_iter().map(|gh| gh.inner.clone()).collect());
let std = StandardizedState::new(
self.inner.clone(),
&self.db.inner,
room_info,
include_tensor,
rs_history,
);
serde_json::to_string(&std).map_err(|e| {
pyo3::exceptions::PyValueError::new_err(format!("Serialization error: {}", e))
})
}
#[getter]
fn performance_results(&self) -> String {
serde_json::to_string(&self.inner.ui.performance_results).unwrap_or_default()
}
#[getter]
fn pending_card_id(&self) -> i32 {
self.inner
.interaction_stack
.last()
.map(|p| {
if p.card_id >= 0 {
p.card_id as i32
} else {
p.ctx.source_card_id as i32
}
})
.unwrap_or(-1)
}
#[getter]
fn pending_ab_idx(&self) -> i32 {
self.inner
.interaction_stack
.last()
.map(|p| p.ability_index as i32)
.unwrap_or(-1)
}
#[getter]
fn pending_effect_opcode(&self) -> i32 {
self.inner
.interaction_stack
.last()
.map(|p| p.effect_opcode as i32)
.unwrap_or(-1)
}
#[getter]
fn pending_choice_type(&self) -> String {
self.inner
.interaction_stack
.last()
.map(|p| p.choice_type.as_str().to_string())
.unwrap_or_default()
}
#[getter]
fn pending_choice_text(&self) -> String {
self.inner
.interaction_stack
.last()
.map(|p| p.choice_text.clone())
.unwrap_or_default()
}
#[getter]
fn yell_cards(&self) -> Vec<u32> {
// Moved to GameState
Vec::new()
}
#[setter]
fn set_yell_cards(&mut self, _val: Vec<u32>) {
// Moved to GameState
}
#[getter]
fn pending_area_idx(&self) -> i32 {
if let Some(pi) = self.inner.interaction_stack.last() {
pi.ctx.area_idx as i32
} else {
-1
}
}
#[getter]
fn pending_player_id(&self) -> i32 {
if let Some(pi) = self.inner.interaction_stack.last() {
pi.ctx.player_id as i32
} else {
-1
}
}
#[getter]
fn last_performance_results(&self) -> String {
serde_json::to_string(&self.inner.ui.last_performance_results)
.unwrap_or_else(|_| "{}".to_string())
}
#[getter]
fn performance_history(&self) -> String {
serde_json::to_string(&self.inner.ui.performance_history)
.unwrap_or_else(|_| "[]".to_string())
}
#[getter]
fn pending_choices(&self) -> Vec<(String, String)> {
use crate::core::enums::O_ACTIVATE_MEMBER;
use crate::core::enums::O_COLOR_SELECT;
use crate::core::enums::O_LOOK_AND_CHOOSE;
use crate::core::enums::O_MOVE_MEMBER;
use crate::core::enums::O_MOVE_TO_DISCARD;
use crate::core::enums::O_OPPONENT_CHOOSE;
use crate::core::enums::O_ORDER_DECK;
use crate::core::enums::O_PLAY_MEMBER_FROM_HAND;
use crate::core::enums::O_RECOVER_LIVE;
use crate::core::enums::O_RECOVER_MEMBER;
use crate::core::enums::O_REVEAL_CARDS;
use crate::core::enums::O_SELECT_CARDS;
use crate::core::enums::O_SELECT_MODE;
use crate::core::enums::O_TAP_OPPONENT;
let mut result = Vec::new();
let op = self
.inner
.interaction_stack
.last()
.map(|p| p.effect_opcode)
.unwrap_or(-1);
let p_idx = if let Some(pi) = self.inner.interaction_stack.last() {
pi.ctx.player_id as usize
} else {
self.inner.current_player as usize
};
let base_params = self
.inner
.interaction_stack
.last()
.map(|pi| {
let source_card_id = if pi.card_id >= 0 {
pi.card_id
} else {
pi.ctx.source_card_id
};
serde_json::json!({
"source_card_id": source_card_id,
"source_player": pi.ctx.player_id,
"source_area": pi.ctx.area_idx,
"area": pi.ctx.area_idx,
"ability_index": pi.ability_index,
"effect_opcode": pi.effect_opcode,
"target_slot": pi.target_slot,
"choice_text": pi.choice_text,
})
})
.unwrap_or_else(|| serde_json::json!({}));
if op == O_ORDER_DECK
|| op == O_LOOK_AND_CHOOSE
|| op == O_REVEAL_CARDS
|| op == O_RECOVER_LIVE
|| op == O_RECOVER_MEMBER
{
let looked = &self.inner.players[p_idx].looked_cards;
let mut params = base_params;
if let Some(obj) = params.as_object_mut() {
obj.insert("cards".to_string(), serde_json::json!(looked));
}
let type_str = if op == O_ORDER_DECK {
ChoiceType::OrderDeck.as_str()
} else {
"SELECT_FROM_LIST"
};
result.push((type_str.to_string(), params.to_string()));
} else if op == O_TAP_OPPONENT {
let mut params = base_params;
if let Some(obj) = params.as_object_mut() {
obj.insert(
"target_player".to_string(),
serde_json::json!(1 - self.inner.interaction_stack.last().map(|pi| pi.ctx.activator_id).unwrap_or(self.inner.current_player)),
);
}
result.push(("TARGET_OPPONENT_MEMBER".to_string(), params.to_string()));
} else if op == O_MOVE_MEMBER {
result.push(("MOVE_MEMBER".to_string(), base_params.to_string()));
} else if op == O_ACTIVATE_MEMBER {
result.push(("TAP_MEMBER".to_string(), base_params.to_string()));
} else if op == O_COLOR_SELECT {
result.push((ChoiceType::ColorSelect.as_str().to_string(), base_params.to_string()));
} else if op == O_MOVE_TO_DISCARD {
result.push((ChoiceType::SelectHandDiscard.as_str().to_string(), base_params.to_string()));
} else if op == O_PLAY_MEMBER_FROM_HAND {
result.push((ChoiceType::SelectHandPlay.to_string(), base_params.to_string()));
} else if op == O_SELECT_CARDS {
result.push(("SELECT_FROM_LIST".to_string(), base_params.to_string()));
} else if op == O_OPPONENT_CHOOSE {
result.push((ChoiceType::OpponentChoose.to_string(), base_params.to_string()));
} else if op == O_SELECT_MODE {
// We might need to store the options in the state if we want better labels
result.push((ChoiceType::SelectMode.to_string(), base_params.to_string()));
}
result
}
#[getter]
fn pending_effects(&self) -> Vec<String> {
Vec::new()
}
fn get_player(&self, idx: usize) -> PyResult<PyPlayerState> {
if idx < 2 {
Ok(PyPlayerState {
inner: self.inner.players[idx].clone(),
})
} else {
Err(pyo3::exceptions::PyIndexError::new_err(
"Player index out of bounds",
))
}
}
fn initialize_game(
&mut self,
p0_deck: Vec<u32>,
p1_deck: Vec<u32>,
p0_energy: Vec<u32>,
p1_energy: Vec<u32>,
p0_lives: Vec<u32>,
p1_lives: Vec<u32>,
) {
let p0_d: Vec<i32> = p0_deck.into_iter().map(|x| x as i32).collect();
let p1_d: Vec<i32> = p1_deck.into_iter().map(|x| x as i32).collect();
let p0_e: Vec<i32> = p0_energy.into_iter().map(|x| x as i32).collect();
let p1_e: Vec<i32> = p1_energy.into_iter().map(|x| x as i32).collect();
let p0_l: Vec<i32> = p0_lives.into_iter().map(|x| x as i32).collect();
let p1_l: Vec<i32> = p1_lives.into_iter().map(|x| x as i32).collect();
self.inner
.initialize_game(p0_d, p1_d, p0_e, p1_e, p0_l, p1_l);
}
fn initialize_game_with_seed(
&mut self,
p0_deck: Vec<u32>,
p1_deck: Vec<u32>,
p0_energy: Vec<u32>,
p1_energy: Vec<u32>,
p0_lives: Vec<u32>,
p1_lives: Vec<u32>,
seed: u64,
) {
let p0_d: Vec<i32> = p0_deck.into_iter().map(|x| x as i32).collect();
let p1_d: Vec<i32> = p1_deck.into_iter().map(|x| x as i32).collect();
let p0_e: Vec<i32> = p0_energy.into_iter().map(|x| x as i32).collect();
let p1_e: Vec<i32> = p1_energy.into_iter().map(|x| x as i32).collect();
let p0_l: Vec<i32> = p0_lives.into_iter().map(|x| x as i32).collect();
let p1_l: Vec<i32> = p1_lives.into_iter().map(|x| x as i32).collect();
self.inner
.initialize_game_with_seed(p0_d, p1_d, p0_e, p1_e, p0_l, p1_l, Some(seed));
}
fn get_legal_actions(&mut self) -> Vec<bool> {
self.inner.get_legal_actions_into(
&self.db.inner,
self.inner.current_player as usize,
&mut self.legal_action_buffer,
);
self.legal_action_buffer.clone()
}
fn get_legal_action_ids(&mut self) -> Vec<i32> {
self.inner.get_legal_action_ids(&self.db.inner)
}
fn get_legal_action_ids_for_player(&mut self, p_idx: usize) -> Vec<i32> {
self.inner
.get_legal_action_ids_for_player(&self.db.inner, p_idx)
}
fn get_observation(&self) -> Vec<f32> {
self.inner.get_observation(&self.db.inner)
}
pub fn to_alphazero_tensor(&self) -> Vec<f32> {
self.inner.to_alphazero_tensor(&self.db.inner)
}
pub fn to_vanilla_tensor(&self) -> Vec<f32> {
use crate::core::alphazero_encoding_vanilla::AlphaZeroVanillaEncoding;
self.inner.to_vanilla_tensor(&self.db.inner)
}
}
// Second #[pymethods] block — PyO3 abi3 has a per-block inventory limit
#[pymethods]
impl PyGameState {
/// Builds the verbose label for `action_id`, using the current state and database.
pub fn get_verbose_label(&self, action_id: i32) -> String {
    use crate::core::logic::ActionFactory;
    let state = &self.inner;
    let db = &self.db.inner;
    ActionFactory::get_verbose_action_label(action_id, state, db)
}
/// Returns a fixed marker string.
pub fn test_method(&self) -> String {
    String::from("test_ok_v3")
}
/// Snapshot of the top pending interaction, or `None` when the stack is empty.
fn get_interaction(&self) -> Option<PyPendingInteraction> {
    let top = self.inner.interaction_stack.last()?;
    Some(PyPendingInteraction {
        choice_type: top.choice_type.as_str().to_string(),
        filter_attr: top.filter_attr,
        // The context is exposed as its Debug text only.
        ctx: format!("{:?}", top.ctx),
    })
}
/// `true` once the game has reached `Phase::Terminal`.
fn is_terminal(&self) -> bool {
    matches!(self.inner.phase, Phase::Terminal)
}
/// Winner as reported by the engine's `get_winner`.
fn get_winner(&self) -> i32 {
    let winner = self.inner.get_winner();
    winner
}
/// Effective blade count of one stage slot of player `p_idx`.
fn get_effective_blades(&self, p_idx: usize, slot_idx: usize) -> u32 {
    let db = &self.db.inner;
    self.inner.get_effective_blades(p_idx, slot_idx, db, 0)
}
/// Effective hearts of one stage slot of player `p_idx`, as a 7-entry array.
fn get_effective_hearts(&self, p_idx: usize, slot_idx: usize) -> [u8; 7] {
    let db = &self.db.inner;
    let hearts = self.inner.get_effective_hearts(p_idx, slot_idx, db, 0);
    hearts.to_array()
}
/// Total blade count of player `p_idx`.
fn get_total_blades(&self, p_idx: usize) -> u32 {
    let db = &self.db.inner;
    self.inner.get_total_blades(p_idx, db, 0)
}
/// Total hearts of player `p_idx`, widened to a 7-entry `u32` array.
fn get_total_hearts(&self, p_idx: usize) -> [u32; 7] {
    let db = &self.db.inner;
    let totals = self.inner.get_total_hearts(p_idx, db, 0).to_array();
    totals.map(|count| count as u32)
}
/// Cost for player `p_idx` to play `card_id` into `slot_idx`, as computed by the engine.
fn get_member_cost(&self, p_idx: usize, card_id: i32, slot_idx: i32) -> i32 {
    let db = &self.db.inner;
    let slot = slot_idx as i16;
    self.inner.get_member_cost(p_idx, card_id, slot, -1, db, 0)
}
/// Forwards the mulligan of `player_idx` with the given discard indices to the engine.
fn execute_mulligan(&mut self, player_idx: usize, discard_indices: Vec<usize>) {
    let state = &mut self.inner;
    state.execute_mulligan(player_idx, discard_indices);
}
/// Applies `action` to the game.
///
/// In debug mode the engine's diagnostics are dumped first. An engine error is
/// raised in Python as `ValueError`.
fn step(&mut self, action: i32) -> PyResult<()> {
    let db = &self.db.inner;
    if self.inner.debug.debug_mode {
        self.inner.dump_diagnostics(db);
    }
    match self.inner.step(db, action) {
        Ok(()) => Ok(()),
        Err(e) => Err(pyo3::exceptions::PyValueError::new_err(e)),
    }
}
/// Short, state-independent label for `action_id`.
fn get_action_label(&self, action_id: i32) -> String {
    use crate::core::logic::ActionFactory;
    ActionFactory::get_action_label(action_id)
}
/// Runs the engine's `auto_step`.
///
/// The `_db` argument is accepted but ignored; the database stored on this
/// wrapper is used instead.
fn auto_step(&mut self, _db: &PyCardDatabase) {
    let db = &self.db.inner;
    self.inner.auto_step(db);
}
/// Runs raw `bytecode` against the state with a hand-built ability context.
///
/// `player_id` is used both as the owning player and as the activator; the
/// `i32` slot/choice/colour arguments are narrowed to `i16` with `as`.
fn debug_execute_bytecode(
    &mut self,
    bytecode: Vec<i32>,
    player_id: u8,
    area_idx: i32,
    source_card_id: i32,
    target_slot: i32,
    choice_index: i32,
    selected_color: i32,
) {
    let program = std::sync::Arc::new(bytecode);
    let context = crate::core::logic::AbilityContext {
        // Caller-supplied values.
        player_id,
        activator_id: player_id,
        area_idx: area_idx as i16,
        source_card_id,
        target_slot: target_slot as i16,
        choice_index: choice_index as i16,
        selected_color: selected_color as i16,
        // Fixed starting values.
        target_card_id: -1,
        program_counter: 0,
        ability_index: -1,
        v_remaining: -1,
        trigger_type: Default::default(),
        original_phase: None,
        original_current_player: None,
        repeat_count: 0,
        selected_cards: Vec::new(),
        v_accumulated: 0,
        auto_pick: false,
    };
    self.inner
        .resolve_bytecode(&self.db.inner, program, &context);
}
/// Forwards to the engine's `integrated_step` and returns its `(f32, bool)` result.
fn integrated_step(
    &mut self,
    action: i32,
    opp_mode: u8,
    mcts_sims: usize,
    enable_rollout: bool,
) -> (f32, bool) {
    self.inner.integrated_step(
        &self.db.inner,
        action,
        opp_mode,
        mcts_sims,
        enable_rollout,
    )
}
/// Forwards to the engine's `play_asymmetric_match`, where each player has its own
/// simulation budget, heuristic id and rollout switch; returns the engine's
/// `(i32, u32)` result unchanged.
#[pyo3(signature = (p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon=SearchHorizon::GameEnd(), p0_rollout=true, p1_rollout=true))]
fn play_asymmetric_match(
    &mut self,
    p0_sims: usize,
    p1_sims: usize,
    p0_heuristic_id: i32,
    p1_heuristic_id: i32,
    horizon: SearchHorizon,
    p0_rollout: bool,
    p1_rollout: bool,
) -> (i32, u32) {
    self.inner.play_asymmetric_match(
        &self.db.inner,
        p0_sims,
        p1_sims,
        p0_heuristic_id,
        p1_heuristic_id,
        horizon,
        p0_rollout,
        p1_rollout,
    )
}
/// Forwards to the engine's `play_mirror_match` with a single rollout switch shared
/// by both players; returns the engine's `(i32, u32)` result unchanged.
#[pyo3(signature = (p0_sims, p1_sims, p0_heuristic_id, p1_heuristic_id, horizon=SearchHorizon::GameEnd(), enable_rollout=true))]
fn play_mirror_match(
    &mut self,
    p0_sims: usize,
    p1_sims: usize,
    p0_heuristic_id: i32,
    p1_heuristic_id: i32,
    horizon: SearchHorizon,
    enable_rollout: bool,
) -> (i32, u32) {
    self.inner.play_mirror_match(
        &self.db.inner,
        p0_sims,
        p1_sims,
        p0_heuristic_id,
        p1_heuristic_id,
        horizon,
        enable_rollout,
    )
}
/// Lets the engine take the opponent's step via `step_opponent`.
fn step_opponent(&mut self) {
    self.inner.step_opponent(&self.db.inner);
}
/// Opponent step driven by MCTS with `sims` simulations.
///
/// Uses `OriginalHeuristic` with `config`, or the default config when omitted.
#[pyo3(signature = (sims, config=None))]
fn step_opponent_mcts(&mut self, sims: usize, config: Option<HeuristicConfig>) {
    let heuristic = OriginalHeuristic {
        config: config.unwrap_or_default(),
    };
    self.inner
        .step_opponent_mcts(&self.db.inner, sims, &heuristic);
}
/// Opponent step chosen greedily.
///
/// Uses `OriginalHeuristic` with `config`, or the default config when omitted.
#[pyo3(signature = (config=None))]
fn step_opponent_greedy(&mut self, config: Option<HeuristicConfig>) {
    let heuristic = OriginalHeuristic {
        config: config.unwrap_or_default(),
    };
    self.inner.step_opponent_greedy(&self.db.inner, &heuristic);
}
/// Execute opponent's full turn using TurnSequencer planner for vanilla mode.
/// This uses the success-count-first heuristic optimized for lower turn counts.
#[pyo3(signature = ())]
fn step_opponent_turnseq(&mut self) {
    use crate::core::logic::turn_sequencer::TurnSequencer;
    let db = &self.db.inner;
    let (planned, _, _, _) = TurnSequencer::plan_full_turn(&self.inner, db);
    // Play the planned actions in order; stop as soon as the game is terminal
    // or the engine rejects an action.
    for &action in planned.iter() {
        if self.inner.is_terminal() || self.inner.step(db, action).is_err() {
            break;
        }
    }
}
/// Greedy action for player `p_idx` under the selected heuristic.
///
/// `heuristic_id` 1 and 2 both select `LegacyHeuristic`; any other value selects
/// `OriginalHeuristic`. `_db` is ignored in favour of the wrapper's own database.
#[pyo3(signature = (_db, p_idx, heuristic_id, config=None))]
fn get_greedy_action(
    &mut self,
    _db: &PyCardDatabase,
    p_idx: usize,
    heuristic_id: i32,
    config: Option<HeuristicConfig>,
) -> i32 {
    let db = &self.db.inner;
    let cfg = config.unwrap_or_default();
    match heuristic_id {
        1 | 2 => {
            let heuristic = LegacyHeuristic { config: cfg };
            self.inner.get_greedy_action(db, p_idx, &heuristic)
        }
        _ => {
            let heuristic = OriginalHeuristic { config: cfg };
            self.inner.get_greedy_action(db, p_idx, &heuristic)
        }
    }
}
/// `(action, value)` pairs from the engine's greedy evaluation for player `p_idx`.
///
/// `heuristic_id` 1 and 2 both select `LegacyHeuristic`; any other value selects
/// `OriginalHeuristic`. `_db` is ignored in favour of the wrapper's own database.
#[pyo3(signature = (_db, p_idx, heuristic_id, config=None))]
fn get_greedy_evaluations(
    &mut self,
    _db: &PyCardDatabase,
    p_idx: usize,
    heuristic_id: i32,
    config: Option<HeuristicConfig>,
) -> Vec<(i32, f32)> {
    let db = &self.db.inner;
    let cfg = config.unwrap_or_default();
    match heuristic_id {
        1 | 2 => {
            let heuristic = LegacyHeuristic { config: cfg };
            self.inner.get_greedy_evaluations(db, p_idx, &heuristic)
        }
        _ => {
            let heuristic = OriginalHeuristic { config: cfg };
            self.inner.get_greedy_evaluations(db, p_idx, &heuristic)
        }
    }
}
/// Evaluates the current state in `EvalMode::Normal` against the two baseline scores.
///
/// `heuristic_id` 1 and 2 both select `LegacyHeuristic`; any other value selects
/// `OriginalHeuristic`.
#[pyo3(signature = (heuristic_id, baseline_score0=0, baseline_score1=0, config=None))]
fn evaluate(
    &self,
    heuristic_id: i32,
    baseline_score0: u32,
    baseline_score1: u32,
    config: Option<HeuristicConfig>,
) -> f32 {
    let db = &self.db.inner;
    let cfg = config.unwrap_or_default();
    let mode = EvalMode::Normal;
    match heuristic_id {
        1 | 2 => {
            let heuristic = LegacyHeuristic { config: cfg };
            self.inner
                .evaluate(db, baseline_score0, baseline_score1, mode, &heuristic)
        }
        _ => {
            let heuristic = OriginalHeuristic { config: cfg };
            self.inner
                .evaluate(db, baseline_score0, baseline_score1, mode, &heuristic)
        }
    }
}
/// MCTS suggestions from the engine as `(i32, f32, u32)` triples.
#[pyo3(signature = (sims, timeout_sec=0.0, horizon=SearchHorizon::GameEnd(), eval_mode=EvalMode::Blind))]
fn get_mcts_suggestions(
    &mut self,
    sims: usize,
    timeout_sec: f32,
    horizon: SearchHorizon,
    eval_mode: EvalMode,
) -> Vec<(i32, f32, u32)> {
    self.inner
        .get_mcts_suggestions(&self.db.inner, sims, timeout_sec, horizon, eval_mode)
}
/// Like `get_mcts_suggestions`, but evaluates with an `OriginalHeuristic` built
/// from `config` (default config when omitted).
#[pyo3(signature = (sims, timeout_sec=0.0, horizon=SearchHorizon::GameEnd(), eval_mode=EvalMode::Blind, config=None))]
fn get_mcts_suggestions_with_config(
    &mut self,
    sims: usize,
    timeout_sec: f32,
    horizon: SearchHorizon,
    eval_mode: EvalMode,
    config: Option<HeuristicConfig>,
) -> Vec<(i32, f32, u32)> {
    let heuristic = OriginalHeuristic {
        config: config.unwrap_or_default(),
    };
    self.inner.get_mcts_suggestions_ext(
        &self.db.inner,
        sims,
        timeout_sec,
        horizon,
        eval_mode,
        &heuristic,
    )
}
/// Runs an MCTS search backed by the given AlphaZero evaluator.
///
/// The search uses no timeout (`0.0`), the `GameEnd` horizon and a default
/// `OriginalHeuristic`; the profiler returned by the search is discarded.
#[pyo3(signature = (sims, evaluator, batch_size=16))]
fn search_mcts_alphazero(
    &mut self,
    sims: usize,
    evaluator: &PyAlphaZeroEvaluator,
    batch_size: usize,
) -> Vec<(i32, f32, u32)> {
    let heuristic = OriginalHeuristic::default();
    let mut searcher = MCTS::with_evaluator(evaluator.evaluator.clone(), batch_size);
    let (suggestions, _profiler) = searcher.search(
        &self.inner,
        &self.db.inner,
        sims,
        0.0,
        SearchHorizon::GameEnd(),
        &heuristic,
    );
    suggestions
}
/// Sets the phase from its numeric code (-1 through 10).
///
/// Any code outside that range selects `Phase::Setup`.
#[setter]
fn set_phase(&mut self, val: i8) {
    let phase = match val {
        -1 => Phase::MulliganP1,
        0 => Phase::MulliganP2,
        1 => Phase::Active,
        2 => Phase::Energy,
        3 => Phase::Draw,
        4 => Phase::Main,
        5 => Phase::LiveSet,
        6 => Phase::PerformanceP1,
        7 => Phase::PerformanceP2,
        8 => Phase::LiveResult,
        9 => Phase::Terminal,
        10 => Phase::Response,
        _ => Phase::Setup,
    };
    self.inner.phase = phase;
}
/// Replaces player `idx` (0 or 1) with the given state; raises `IndexError` otherwise.
///
/// The discard size of the incoming state is written to the game log first.
fn set_player(&mut self, idx: usize, player: PyPlayerState) -> PyResult<()> {
    if idx >= 2 {
        return Err(pyo3::exceptions::PyIndexError::new_err(
            "Player index out of bounds",
        ));
    }
    let discard_len = player.inner.discard.len();
    self.inner.log(format!(
        "set_player {}: Discard len = {}",
        idx, discard_len
    ));
    self.inner.players[idx] = player.inner;
    Ok(())
}
/// Puts `card_id` into a stage slot; out-of-range player or slot indices are ignored.
fn set_stage_card(&mut self, p_idx: usize, slot_idx: usize, card_id: i32) {
    if p_idx >= 2 || slot_idx >= STAGE_SLOT_COUNT {
        return;
    }
    self.inner.players[p_idx].stage[slot_idx] = card_id;
}
/// Puts `card_id` into a live-zone slot and sets its revealed flag.
///
/// Out-of-range player or slot indices are ignored (the slot bound is
/// `STAGE_SLOT_COUNT`).
fn set_live_card(&mut self, p_idx: usize, slot_idx: usize, card_id: i32, revealed: bool) {
    if p_idx >= 2 || slot_idx >= STAGE_SLOT_COUNT {
        return;
    }
    let player = &mut self.inner.players[p_idx];
    player.live_zone[slot_idx] = card_id;
    player.set_revealed(slot_idx, revealed);
}
/// Replaces the hand of player `p_idx` and stamps every card as added on the
/// current turn. An out-of-range player index is ignored.
fn set_hand_cards(&mut self, p_idx: usize, cards: Vec<u32>) {
    if p_idx >= 2 {
        return;
    }
    let turn = self.inner.turn as i32;
    let player = &mut self.inner.players[p_idx];
    player.hand = cards.into_iter().map(|id| id as i32).collect();
    // One "added on turn" entry per card now in hand.
    player.hand_added_turn = SmallVec::from_vec(vec![turn; player.hand.len()]);
}
/// Replaces the discard pile of player `p_idx`; an out-of-range index is ignored.
fn set_discard_cards(&mut self, p_idx: usize, cards: Vec<u32>) {
    if p_idx >= 2 {
        return;
    }
    self.inner.players[p_idx].discard = cards.into_iter().map(|id| id as i32).collect();
}
/// Replaces the revealed cards of player `p_idx` with `cards`.
///
/// The cards are stored in `looked_cards`, the shared buffer the engine uses
/// for revealing cards. An out-of-range `p_idx` is ignored silently.
fn set_revealed_cards(&mut self, p_idx: usize, cards: Vec<u32>) {
    if p_idx >= 2 {
        return;
    }
    let converted = cards.into_iter().map(|card| card as i32).collect();
    self.inner.players[p_idx].looked_cards = converted;
}
/// Replaces the deck of player `p_idx` with `cards`.
///
/// An out-of-range `p_idx` is ignored silently.
fn set_deck_cards(&mut self, p_idx: usize, cards: Vec<u32>) {
    if p_idx >= 2 {
        return;
    }
    let converted = cards.into_iter().map(|card| card as i32).collect();
    self.inner.players[p_idx].deck = converted;
}
/// Replaces the energy zone of player `p_idx` with `cards` and clears the
/// tapped-energy mask.
///
/// An out-of-range `p_idx` is ignored silently.
fn set_energy_cards(&mut self, p_idx: usize, cards: Vec<u32>) {
    if p_idx >= 2 {
        return;
    }
    let player = &mut self.inner.players[p_idx];
    player.energy_zone = cards.into_iter().map(|card| card as i32).collect();
    // The old mask described the previous zone contents, so reset it.
    player.tapped_energy_mask = 0;
}
/// Copies up to the first three entries of `cards` into the live zone of
/// player `p_idx`, slot by slot.
///
/// Slots beyond `cards.len()` keep their current value. An out-of-range
/// `p_idx` is ignored silently.
fn set_live_cards(&mut self, p_idx: usize, cards: Vec<u32>) {
    if p_idx >= 2 {
        return;
    }
    let live_zone = &mut self.inner.players[p_idx].live_zone;
    for (slot, card) in cards.into_iter().take(3).enumerate() {
        live_zone[slot] = card as i32;
    }
}
fn resolve_bytecode(&mut self, bytecode: Vec<i32>, player_id: u8, _area_idx: i32) {
let ctx = crate::core::logic::AbilityContext {
player_id,
activator_id: player_id,
target_card_id: -1,
original_phase: None,
..crate::core::logic::AbilityContext::default()
};
self.inner
.resolve_bytecode(&self.db.inner, std::sync::Arc::new(bytecode), &ctx);
}
/// Fires the abilities registered for `trigger` on behalf of `player_id`.
///
/// `trigger` is narrowed to `i8` and reinterpreted as a `TriggerType` without
/// any validation; see the safety note below.
fn trigger_abilities(&mut self, trigger: i32, player_id: u8) {
// SAFETY: NOTE(review) — this is sound only if `trigger as i8` is a valid
// `TriggerType` discriminant and the enum is represented as an `i8`. Nothing
// here checks that, and the value comes straight from the Python caller, so an
// unexpected integer is undefined behaviour. The `as i8` cast also truncates
// values outside the `i8` range. Consider a checked conversion.
let trigger_type =
unsafe { std::mem::transmute::<i8, crate::core::enums::TriggerType>(trigger as i8) };
// The context's own `trigger_type` field is set to `None`; the decoded trigger
// is passed separately to `trigger_abilities` below.
// NOTE(review): confirm the context is not meant to carry the decoded value.
let ctx = crate::core::logic::AbilityContext {
player_id,
activator_id: player_id,
target_card_id: -1,
trigger_type: crate::core::enums::TriggerType::None,
..Default::default()
};
self.inner
.trigger_abilities(&self.db.inner, trigger_type, &ctx);
}
/// Activates ability `ab_idx` of the card in slot `slot_idx`.
///
/// `_player_id` and `_card_id` are accepted for interface compatibility but
/// are not used.
///
/// # Errors
/// Returns `ValueError` when `slot_idx` or `ab_idx` is negative, or when the
/// engine's `activate_ability` reports an error.
fn trigger_ability_on_card(
    &mut self,
    _player_id: u8,
    _card_id: i32,
    slot_idx: i32,
    ab_idx: i32,
) -> PyResult<()> {
    // A plain `as usize` cast would turn a negative index into a huge positive
    // one; reject negatives up front with a clear error instead.
    let slot = usize::try_from(slot_idx).map_err(|_| {
        pyo3::exceptions::PyValueError::new_err(format!(
            "slot_idx must be non-negative, got {slot_idx}"
        ))
    })?;
    let ability = usize::try_from(ab_idx).map_err(|_| {
        pyo3::exceptions::PyValueError::new_err(format!(
            "ab_idx must be non-negative, got {ab_idx}"
        ))
    })?;
    let db = &self.db.inner;
    self.inner
        .activate_ability(db, slot, ability)
        .map_err(pyo3::exceptions::PyValueError::new_err)
}
/// Clears the `used_abilities` record of player `p_idx`.
///
/// An out-of-range `p_idx` is ignored silently.
fn clear_once_per_turn_flags(&mut self, p_idx: usize) {
    if p_idx >= 2 {
        return;
    }
    self.inner.players[p_idx].used_abilities.clear();
}
/// Starts a turn by running the engine's active phase against the card
/// database held by this state.
fn start_turn(&mut self) {
    let db = &self.db.inner;
    self.inner.do_active_phase(db);
}
/// Runs a parallel heuristic MCTS search from the current state.
///
/// * `num_sims`, `seconds`, `horizon` are forwarded unchanged to
///   `MCTS::search_parallel`.
/// * `heuristic_type` selects the evaluation heuristic: `"legacy"` uses
///   `LegacyHeuristic`; every other value uses `OriginalHeuristic`. That
///   includes `"resnet"` and `"hybrid"`, which have no dedicated path in this
///   binding.
/// * `eval_mode` is forwarded as a flag that is `true` for `EvalMode::Blind`.
/// * `config` supplies the heuristic configuration; the default is used when
///   it is `None`.
/// * `_model_path` is accepted for interface compatibility and ignored.
///
/// Returns the suggestions produced by `search_parallel`.
#[pyo3(signature = (num_sims=0, seconds=0.0, heuristic_type="original", horizon=SearchHorizon::GameEnd(), eval_mode=EvalMode::Blind, config=None, _model_path=None))]
fn search_mcts(
    &self,
    num_sims: usize,
    seconds: f32,
    heuristic_type: &str,
    horizon: SearchHorizon,
    eval_mode: EvalMode,
    config: Option<HeuristicConfig>,
    _model_path: Option<&str>,
) -> Vec<(i32, f32, u32)> {
    let cfg = config.unwrap_or_default();
    let blind = eval_mode == EvalMode::Blind;
    let mcts = crate::core::mcts::MCTS::new();
    match heuristic_type {
        "legacy" => {
            let h = LegacyHeuristic { config: cfg };
            mcts.search_parallel(
                &self.inner,
                &self.db.inner,
                num_sims,
                seconds,
                horizon,
                &h,
                blind,
            )
        }
        // "resnet" and "hybrid" used to have a separate cfg-gated early return
        // that ran exactly this search, so they are handled here with every
        // other heuristic name.
        _ => {
            let h = OriginalHeuristic { config: cfg };
            mcts.search_parallel(
                &self.inner,
                &self.db.inner,
                num_sims,
                seconds,
                horizon,
                &h,
                blind,
            )
        }
    }
}
/// Plays `num_games` uniformly random playouts from a clone of the current
/// state and reports throughput and per-action timing statistics.
///
/// Each step picks a random legal action for the current player, or action
/// `0` when no action is legal, and times the call to `step`. Timings are
/// grouped by a coarse category derived from the action label.
///
/// Returns a Python object, built by `json.loads`, with the keys
/// `total_games`, `total_moves`, `total_meaningful_moves` (moves whose action
/// is not `0`), `gameplay_seconds`, `mps` (moves per second) and
/// `action_timings` (per category: `count`, `total_time`, `avg_time`).
///
/// # Panics
/// Panics if the Python `json` module cannot be imported or fails to parse
/// the generated document.
#[pyo3(signature = (db, num_games))]
pub fn sim_random_games(&self, db: &PyCardDatabase, num_games: usize) -> PyObject {
    // Cap on the steps of a single playout, guarding against games that never
    // reach the terminal phase. The cap is per game: checking the cumulative
    // counter instead would cut every later game to one step once the total
    // passed the limit.
    const MAX_MOVES_PER_GAME: u64 = 1_000_000;
    // Ordered (label substring, category) pairs; the first match wins, which
    // keeps the label-to-category map small.
    const LABEL_CATEGORIES: &[(&str, &str)] = &[
        ("Mulligan", "Mulligan"),
        ("Set Live", "SetLive"),
        ("Play Hand", "PlayMember"),
        ("Activate Member", "ActivateMember"),
        ("Activate from Hand", "ActivateFromHand"),
        ("Activate from Discard", "ActivateFromDiscard"),
        ("Pass", "Pass"),
        ("Select Mode", "SelectMode"),
        ("Select Color", "SelectColor"),
        ("Select Stage Slot", "SelectStageSlot"),
        ("Select Left Slot", "SelectStageSlot"),
        ("Select Mid Slot", "SelectStageSlot"),
        ("Select Right Slot", "SelectStageSlot"),
        ("Select Choice", "SelectChoice"),
        ("Turn Choice", "TurnChoice"),
    ];

    let mut rng = SmallRng::from_os_rng();
    let mut total_moves: u64 = 0;
    let mut total_meaningful_moves: u64 = 0;
    let mut gameplay_seconds = 0.0;
    let mut action_stats: std::collections::HashMap<String, (u64, f64)> =
        std::collections::HashMap::new();

    for _ in 0..num_games {
        let mut state = self.inner.clone();
        state.ui.silent = true;
        let mut game_moves: u64 = 0;
        let start = std::time::Instant::now();
        while state.phase != crate::core::logic::Phase::Terminal {
            let mut actions = SmallVec::<[i32; 64]>::new();
            state.generate_legal_actions(&db.inner, state.current_player as usize, &mut actions);
            // Action 0 is used when nothing is legal.
            let action = actions.choose(&mut rng).copied().unwrap_or(0);

            let step_start = std::time::Instant::now();
            let _ = state.step(&db.inner, action);
            let step_duration = step_start.elapsed().as_secs_f64();

            let label = crate::core::logic::ActionFactory::get_action_label(action);
            let matched = LABEL_CATEGORIES
                .iter()
                .find(|(needle, _)| label.contains(*needle));
            let category = match matched {
                Some((_, name)) => name.to_string(),
                // Labels matching no pattern are kept as their own category.
                None => label,
            };

            let entry = action_stats.entry(category).or_insert((0, 0.0));
            entry.0 += 1;
            entry.1 += step_duration;

            total_moves += 1;
            game_moves += 1;
            if action != 0 {
                total_meaningful_moves += 1;
            }
            if game_moves > MAX_MOVES_PER_GAME {
                // Safety break for extreme cases
                break;
            }
        }
        gameplay_seconds += start.elapsed().as_secs_f64();
    }

    Python::with_gil(|py| {
        let mps = if gameplay_seconds > 0.0 {
            total_moves as f64 / gameplay_seconds
        } else {
            0.0
        };
        let mut timing_breakdown = std::collections::HashMap::new();
        for (cat, (count, total_time)) in action_stats {
            timing_breakdown.insert(cat, json!({
                "count": count,
                "total_time": total_time,
                "avg_time": if count > 0 { total_time / count as f64 } else { 0.0 }
            }));
        }
        let results = json!({
            "total_games": num_games,
            "total_moves": total_moves,
            "total_meaningful_moves": total_meaningful_moves,
            "gameplay_seconds": gameplay_seconds,
            "mps": mps,
            "action_timings": timing_breakdown,
        });
        // Round-trip through Python's own `json` module to obtain native
        // dict/list/number objects.
        let json_str = results.to_string();
        let json_mod = py
            .import("json")
            .expect("Python's built-in json module should be importable");
        json_mod
            .call_method1("loads", (json_str,))
            .expect("serde_json output should be valid JSON")
            .to_object(py)
    })
}
/// Plans a full turn with `TurnSequencer::plan_full_turn`.
///
/// Returns `(candidates, sequence, nodes, breakdown)`. `candidates` is always
/// an empty list here, and the value reported by the sequencer is dropped.
#[pyo3(signature = (db))]
pub fn plan_full_turn(&self, db: &PyCardDatabase) -> (Vec<(i32, f32, f32, f32)>, Vec<i32>, usize, (f32, f32)) {
    use crate::core::logic::turn_sequencer::TurnSequencer;
    let (sequence, _value, breakdown, node_count) =
        TurnSequencer::plan_full_turn(&self.inner, &db.inner);
    let no_candidates = Vec::new();
    (no_candidates, sequence, node_count, breakdown)
}
/// Plans a full turn and returns the tuple from
/// `TurnSequencer::plan_full_turn_with_stats` unchanged.
#[pyo3(signature = (db))]
pub fn plan_full_turn_with_stats(&self, db: &PyCardDatabase) -> (Vec<(i32, f32, f32, f32)>, Vec<i32>, usize, f32, (f32, f32)) {
    use crate::core::logic::turn_sequencer::TurnSequencer;
    let state = &self.inner;
    let card_db = &db.inner;
    TurnSequencer::plan_full_turn_with_stats(state, card_db)
}
/// Returns the tuple from `TurnSequencer::find_best_liveset_selection` for
/// the current state, unchanged.
#[pyo3(signature = (db))]
pub fn find_best_liveset_selection(&self, db: &PyCardDatabase) -> (Vec<i32>, usize, u128) {
    use crate::core::logic::turn_sequencer::TurnSequencer;
    let state = &self.inner;
    let card_db = &db.inner;
    TurnSequencer::find_best_liveset_selection(state, card_db)
}
/// Returns the score breakdown for player `p_idx` as a flat tuple.
///
/// The components, in order, are: `board_score`, `live_ev`, `success_val`,
/// `win_bonus`, `hand_momentum`, `cycling_bonus`, `total`.
#[pyo3(signature = (db, p_idx))]
pub fn get_score_breakdown(
    &self,
    db: &PyCardDatabase,
    p_idx: usize,
) -> (f32, f32, f32, f32, f32, f32, f32) {
    use crate::core::logic::turn_sequencer::TurnSequencer;
    let breakdown = TurnSequencer::get_score_breakdown(&self.inner, &db.inner, p_idx);
    let components = (
        breakdown.board_score,
        breakdown.live_ev,
        breakdown.success_val,
        breakdown.win_bonus,
        breakdown.hand_momentum,
        breakdown.cycling_bonus,
        breakdown.total,
    );
    components
}
}
/// A batch of independent game environments that are stepped together.
#[pyclass]
pub struct PyVectorGameState {
/// One game state per environment.
envs: Vec<GameState>,
/// Card database passed to every engine call made on the environments.
db: PyCardDatabase,
/// Deck and live card lists recorded by `initialize`; reused when a finished
/// environment is reset in `step`.
p0_deck: Vec<u32>,
p1_deck: Vec<u32>,
p0_lives: Vec<u32>,
p1_lives: Vec<u32>,
/// Per-environment seeds assigned in `initialize` (base seed plus index).
seeds: Vec<u64>,
/// Forwarded unchanged to `GameState::integrated_step` on every step.
opp_mode: u8,
/// Forwarded unchanged to `GameState::integrated_step` on every step.
mcts_sims: usize,
}
#[pymethods]
impl PyVectorGameState {
#[new]
#[pyo3(signature = (num_envs, db, opp_mode=0, mcts_sims=50))]
fn new(num_envs: usize, db: PyCardDatabase, opp_mode: u8, mcts_sims: usize) -> Self {
let mut envs = Vec::with_capacity(num_envs);
for _ in 0..num_envs {
envs.push(GameState::default());
}
Self {
envs,
db,
p0_deck: Vec::new(),
p1_deck: Vec::new(),
p0_lives: Vec::new(),
p1_lives: Vec::new(),
seeds: vec![0; num_envs],
opp_mode,
mcts_sims,
}
}
/// Records the deck and live card lists and initializes every environment in
/// parallel.
///
/// Environment `i` is seeded with `seed + i`, wrapping around on `u64`
/// overflow so that a very large base seed cannot trigger an arithmetic
/// overflow panic. The recorded lists are reused when `step` resets a
/// finished environment.
fn initialize(
    &mut self,
    p0_deck: Vec<u32>,
    p1_deck: Vec<u32>,
    p0_lives: Vec<u32>,
    p1_lives: Vec<u32>,
    seed: u64,
) {
    self.p0_deck = p0_deck;
    self.p1_deck = p1_deck;
    self.p0_lives = p0_lives;
    self.p1_lives = p1_lives;

    for (offset, slot) in self.seeds.iter_mut().enumerate() {
        *slot = seed.wrapping_add(offset as u64);
    }

    // Convert each list once; every environment then takes a plain copy.
    let to_ids = |cards: &[u32]| -> Vec<i32> { cards.iter().map(|&c| c as i32).collect() };
    let p0_deck_ids = to_ids(&self.p0_deck);
    let p1_deck_ids = to_ids(&self.p1_deck);
    let p0_live_ids = to_ids(&self.p0_lives);
    let p1_live_ids = to_ids(&self.p1_lives);

    let seeds = &self.seeds;
    self.envs
        .par_iter_mut()
        .zip(seeds.par_iter())
        .for_each(|(env, &env_seed)| {
            env.initialize_game_with_seed(
                p0_deck_ids.clone(),
                p1_deck_ids.clone(),
                Vec::new(),
                Vec::new(),
                p0_live_ids.clone(),
                p1_live_ids.clone(),
                Some(env_seed),
            );
        });
}
/// Steps every environment once and writes the results into caller-provided
/// NumPy buffers.
///
/// * `actions` — one action per environment; its length must equal the number
///   of environments.
/// * `obs_out` — receives the observation of every environment after the step
///   (and after any auto-reset), one chunk of 320 floats per environment.
/// * `rewards_out` / `dones_out` — receive the two values returned by
///   `integrated_step` for each environment.
/// * `term_obs_out` — receives the pre-reset observation of each finished
///   environment, packed densely: the k-th finished environment fills the
///   k-th chunk of 320 floats.
///
/// Returns the indices of the environments that finished on this step.
///
/// # Errors
/// Returns `ValueError` when `actions` has the wrong length, and propagates
/// the error from `as_slice`/`as_slice_mut` when a buffer cannot be viewed as
/// a slice.
///
/// NOTE(review): output buffer sizes are not validated. The zipped parallel
/// iterators stop at the shorter side, so an undersized buffer appears to be
/// only partially filled rather than rejected — confirm this is intended.
#[allow(clippy::too_many_arguments)]
fn step<'py>(
&mut self,
_py: Python<'py>,
actions: PyReadonlyArray1<'py, i32>,
obs_out: &Bound<'py, PyArray2<f32>>,
rewards_out: &Bound<'py, PyArray1<f32>>,
dones_out: &Bound<'py, PyArray1<bool>>,
term_obs_out: &Bound<'py, PyArray2<f32>>,
) -> PyResult<Vec<usize>> {
let actions = actions.as_slice()?;
// SAFETY: NOTE(review) — these mutable views are only sound if nothing else
// reads or writes the arrays while this call runs, and if the four output
// arrays do not alias each other (e.g. `obs_out` and `term_obs_out` being the
// same array). Neither condition is checked here; the caller must uphold them.
let obs_slice = unsafe { obs_out.as_slice_mut()? };
let rewards_slice = unsafe { rewards_out.as_slice_mut()? };
let dones_slice = unsafe { dones_out.as_slice_mut()? };
let term_obs_slice = unsafe { term_obs_out.as_slice_mut()? };
let num_envs = self.envs.len();
let db = &self.db.inner;
// Chunk width used for both observation buffers; `get_observations` uses the
// same literal.
let obs_dim = 320;
if actions.len() != num_envs {
return Err(pyo3::exceptions::PyValueError::new_err(
"Action dim mismatch",
));
}
// 1. Step
let opp_mode = self.opp_mode;
let mcts_sims = self.mcts_sims;
let results: Vec<(f32, bool)> = self
.envs
.par_iter_mut()
.zip(actions.par_iter())
.map(|(env, &act)| env.integrated_step(db, act, opp_mode, mcts_sims, true))
.collect();
// Copy each (reward, done) pair into the matching output slots.
results
.par_iter()
.zip(rewards_slice.par_iter_mut())
.zip(dones_slice.par_iter_mut())
.for_each(|((&(r, d), r_out), d_out)| {
*r_out = r;
*d_out = d;
});
// 2. Filter Done
let mut done_indices = Vec::with_capacity(num_envs / 10);
for (i, &(_, done)) in results.iter().enumerate() {
if done {
done_indices.push(i);
}
}
// 3. Write Terminal Obs (Before Reset)
// This must run before step 4, which overwrites the finished states.
if !done_indices.is_empty() {
term_obs_slice
.par_chunks_mut(obs_dim)
.zip(done_indices.par_iter())
.for_each(|(chunk, &env_idx)| {
self.envs[env_idx].write_observation(db, chunk);
});
}
// 4. Auto-Reset
// Finished environments are re-initialized from the recorded decks; the seed
// argument is `None` here, unlike in `initialize`.
let p0_deck = &self.p0_deck;
let p1_deck = &self.p1_deck;
let p0_lives = &self.p0_lives;
let p1_lives = &self.p1_lives;
self.envs
.par_iter_mut()
.zip(results.par_iter())
.for_each(|(env, &(_, done))| {
if done {
env.initialize_game_with_seed(
p0_deck.iter().map(|&x| x as i32).collect(),
p1_deck.iter().map(|&x| x as i32).collect(),
Vec::new(),
Vec::new(),
p0_lives.iter().map(|&x| x as i32).collect(),
p1_lives.iter().map(|&x| x as i32).collect(),
None,
);
}
});
// 5. Write Final Obs
obs_slice
.par_chunks_mut(obs_dim)
.zip(self.envs.par_iter())
.for_each(|(chunk, env)| {
env.write_observation(db, chunk);
});
Ok(done_indices)
}
// Zero-copy observation export.
/// Writes the observation of every environment into `out`, one chunk of 320
/// floats per environment, in parallel.
///
/// # Errors
/// Propagates the error from `as_slice_mut` when `out` cannot be viewed as a
/// slice.
fn get_observations<'py>(
    &self,
    _py: Python<'py>,
    out: &Bound<'py, PyArray2<f32>>,
) -> PyResult<()> {
    const OBS_DIM: usize = 320;
    let db = &self.db.inner;
    // SAFETY: NOTE(review) — sound only if nothing else accesses `out` while
    // this call runs; the caller must uphold this.
    let buffer = unsafe { out.as_slice_mut()? };
    let rows = buffer.par_chunks_mut(OBS_DIM);
    rows.zip(self.envs.par_iter()).for_each(|(row, env)| {
        env.write_observation(db, row);
    });
    Ok(())
}
// Zero-copy action-mask export.
/// Writes the legal-action mask of every environment into `out`, one chunk
/// of `ACTION_SPACE` booleans per environment, in parallel.
///
/// Each mask is computed for that environment's current player.
///
/// # Errors
/// Propagates the error from `as_slice_mut` when `out` cannot be viewed as a
/// slice.
fn get_action_masks<'py>(
    &self,
    _py: Python<'py>,
    out: &Bound<'py, PyArray2<bool>>,
) -> PyResult<()> {
    let row_len = crate::core::logic::ACTION_SPACE;
    let db = &self.db.inner;
    // SAFETY: NOTE(review) — sound only if nothing else accesses `out` while
    // this call runs; the caller must uphold this.
    let buffer = unsafe { out.as_slice_mut()? };
    let rows = buffer.par_chunks_mut(row_len);
    rows.zip(self.envs.par_iter()).for_each(|(row, env)| {
        let mover = env.current_player as usize;
        env.get_legal_actions_into(db, mover, row);
    });
    Ok(())
}
}
/// Python-facing handle for the hybrid (network plus rollout) MCTS.
///
/// Only compiled when the `nn` feature is enabled.
#[cfg(feature = "nn")]
#[pyclass]
pub struct PyHybridMCTS {
/// Shared, mutex-guarded ONNX Runtime session loaded from the model file.
pub session: std::sync::Arc<std::sync::Mutex<ort::session::Session>>,
/// Passed unchanged to `HybridMCTS::new` for every search.
pub neural_weight: f32,
/// Passed unchanged to `HybridMCTS::new` for every search.
pub skip_rollout: bool,
}
#[cfg(feature = "nn")]
#[pymethods]
impl PyHybridMCTS {
    /// Loads the model at `model_path` into a new ONNX Runtime session.
    ///
    /// # Errors
    /// Returns `ValueError` carrying the `ort` error text when the session
    /// builder cannot be created or the model file cannot be loaded.
    #[new]
    #[pyo3(signature = (model_path, neural_weight=0.3, skip_rollout=false))]
    fn new(model_path: &str, neural_weight: f32, skip_rollout: bool) -> PyResult<Self> {
        let to_py_err =
            |e: ort::Error| pyo3::exceptions::PyValueError::new_err(e.to_string());
        let session = ort::session::Session::builder()
            .map_err(to_py_err)?
            .commit_from_file(model_path)
            .map_err(to_py_err)?;
        let shared_session = std::sync::Arc::new(std::sync::Mutex::new(session));
        Ok(Self {
            session: shared_session,
            neural_weight,
            skip_rollout,
        })
    }

    /// Builds a `HybridMCTS` from the stored session and settings and returns
    /// its suggestions for `game`, given `num_sims` and `seconds`.
    #[pyo3(signature = (game, num_sims=0, seconds=0.0))]
    fn get_suggestions(
        &mut self,
        game: &mut PyGameState,
        num_sims: usize,
        seconds: f32,
    ) -> Vec<(i32, f32, u32)> {
        let session = std::sync::Arc::clone(&self.session);
        let mut searcher =
            crate::core::mcts::HybridMCTS::new(session, self.neural_weight, self.skip_rollout);
        searcher.get_suggestions(&game.inner, &game.db.inner, num_sims, seconds)
    }
}
// AlphaZero Tensor Type Enum
/// Selects which tensor encoding a `PyAlphaZeroEvaluator` uses.
///
/// Each variant maps to the `PythonTensorEncoding` variant of the same name
/// when the evaluator is constructed. Exposed to Python as an integer-comparable
/// enum.
#[pyclass(eq, eq_int)]
#[derive(Clone, Copy, PartialEq, Eq, Debug, Default)]
pub enum AlphaZeroTensorType {
/// Default variant.
#[default]
Vanilla = 0,
Original = 1,
}
// PyAlphaZeroEvaluator wrapper for network-guided MCTS
/// Python-facing wrapper around a shared `AlphaZeroEvaluator` trait object.
#[pyclass]
pub struct PyAlphaZeroEvaluator {
/// Shared evaluator; `search_mcts_alphazero` clones this handle for each search.
evaluator: Arc<Box<dyn AlphaZeroEvaluator>>,
}
#[pymethods]
impl PyAlphaZeroEvaluator {
    /// Wraps the Python `model` in an evaluator that uses the tensor encoding
    /// selected by `tensor_type`.
    ///
    /// # Panics
    /// Panics when the crate is built without the `extension-module` feature.
    #[new]
    fn new(model: PyObject, tensor_type: AlphaZeroTensorType) -> Self {
        #[cfg(feature = "extension-module")]
        {
            let encoding = match tensor_type {
                AlphaZeroTensorType::Vanilla => {
                    crate::core::alphazero_evaluator::PythonTensorEncoding::Vanilla
                }
                AlphaZeroTensorType::Original => {
                    crate::core::alphazero_evaluator::PythonTensorEncoding::Original
                }
            };
            let boxed: Box<dyn AlphaZeroEvaluator> = Box::new(
                crate::core::alphazero_evaluator::PyAlphaZeroEvaluator::new(model, encoding),
            );
            Self {
                evaluator: Arc::new(boxed),
            }
        }
        #[cfg(not(feature = "extension-module"))]
        {
            panic!("PyAlphaZeroEvaluator requires extension-module feature");
        }
    }
}
/// Registers every Python-visible class of this binding layer on module `m`.
///
/// # Errors
/// Returns the first error reported by `add_class`.
pub fn register_python_module(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyPlayerState>()?;
m.add_class::<PyGameState>()?;
m.add_class::<PyVectorGameState>()?;
m.add_class::<PyCardDatabase>()?;
m.add_class::<PyPendingInteraction>()?;
// The hybrid MCTS class only exists when the `nn` feature is enabled.
#[cfg(feature = "nn")]
m.add_class::<PyHybridMCTS>()?;
m.add_class::<SearchHorizon>()?;
m.add_class::<EvalMode>()?;
m.add_class::<HeuristicConfig>()?;
m.add_class::<AlphaZeroTensorType>()?;
m.add_class::<PyAlphaZeroEvaluator>()?;
Ok(())
}