using Random
using StatsBase

"""
    simulate_rollout(b::Board, policy, side; rng=MersenneTwister(420))

Play `b` forward until `isterminal(b)` is true, choosing every move by sampling
from the distribution produced by `policy`.

`b` is advanced in place with `domove!`; the board returned is that same object.

# Arguments
- `b::Board`: starting position; it is modified by the rollout.
- `policy`: callable `policy(b, movelist)` returning one weight per move, indexed like
  `movelist`. The result is wrapped in `ProbabilityWeights`.
- `side`: accepted for interface compatibility; the rollout does not currently read it.

# Keywords
- `rng`: random number generator used for move sampling. The default builds a fresh
  `MersenneTwister(420)` on every call, so calls that rely on the default draw the
  same random stream each time. Pass one shared RNG to get distinct but reproducible
  rollouts across calls.

# Returns
`(b, num_sim_moves)`: the board in its terminal state and the number of moves played.
"""
function simulate_rollout(
    b::Board, policy, side; rng = MersenneTwister(420)
)::Tuple{Board, Int64}
    # One move buffer is reused for the whole rollout to avoid reallocating per ply.
    movelist = MoveList(200)
    num_sim_moves = 0
    while !isterminal(b)
        # TODO Use `matein1` possibly to trim leaf nodes in sims?
        moves(b, movelist)
        policy_weights = ProbabilityWeights(policy(b, movelist))
        # Sample with the caller's RNG; without it the global RNG would be used and
        # the `rng` keyword would have no effect.
        domove!(b, sample(rng, movelist, policy_weights))
        # Empty the buffer so the next `moves` call starts from a clean list.
        recycle!(movelist)
        num_sim_moves += 1
    end
    return b, num_sim_moves
end

"""
    CESPF(b::Board, movelist::MoveList)

(C)apture / (E)scape (S)tronger (P)iece (F)irst rollout policy built on `Chess.jl`'s
`see` function.

Each move in `movelist` is scored with `see(b, move)`. The scores are shifted by
`1 + abs(minimum(scores))` so that every weight is at least 1, then divided by their
sum. The result is a probability vector indexed like `movelist`, in which moves with
a higher `see` value are more likely to be picked.

`movelist` must not be empty.
"""
function CESPF(b::Board, movelist::MoveList)
    see_scores = map(x -> see(b, x), movelist)
    # Shift the raw `see` values so that all of them are strictly positive before
    # normalising; `minimum` avoids splatting the whole vector into `min`.
    offset = 1 + abs(minimum(see_scores))
    shifted_scores = offset .+ see_scores
    return shifted_scores / sum(shifted_scores)
end

"""
    CESPF_greedy(b::Board, movelist::MoveList)

Greedy variant of the (C)apture / (E)scape (S)tronger (P)iece (F)irst rollout policy
built on `Chess.jl`'s `see` function.

Each move in `movelist` is scored with `see(b, move)`. Only the moves that reach the
maximum score receive a non-zero probability, and that probability is split evenly
among them. The result is a `Float64` vector indexed like `movelist`.

`movelist` must not be empty.
"""
function CESPF_greedy(b::Board, movelist::MoveList)
    see_scores = map(x -> see(b, x), movelist)
    is_best = see_scores .== maximum(see_scores)
    policy_weights = zeros(length(see_scores))
    # Ties for the best score share the probability mass uniformly.
    policy_weights[is_best] .= 1.0 / count(is_best)
    return policy_weights
end