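# NOTE: the following lines are page-header residue from a web view of this
# file (author, commit message, commit hash, view links, file size); they are
# not part of the Cython source:
#   zjowowen's picture
#   init space
#   079c32c
#   raw
#   history blame
#   3.35 kB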
# distutils: language=c++
# cython:language_level=3
from libcpp.vector cimport vector
cdef class MinMaxStatsList:
    """Python-visible owner of a heap-allocated C++ ``CMinMaxStatsList``.

    The C++ object is created with ``new`` in ``__cinit__`` and released
    with ``del`` in ``__dealloc__``. Module-level helpers such as
    ``batch_traverse`` / ``batch_backpropagate`` reach the raw pointer
    through the ``cmin_max_stats_lst`` attribute.
    """
    cdef CMinMaxStatsList *cmin_max_stats_lst

    def __cinit__(self, int num):
        # ``num`` is handed unchanged to the C++ constructor
        # (presumably the number of per-root statistics -- TODO confirm).
        self.cmin_max_stats_lst = new CMinMaxStatsList(num)

    def __dealloc__(self):
        # Matches the ``new`` performed in ``__cinit__``.
        del self.cmin_max_stats_lst

    def set_delta(self, float value_delta_max):
        """Pass ``value_delta_max`` through to ``CMinMaxStatsList.set_delta``."""
        self.cmin_max_stats_lst.set_delta(value_delta_max)
cdef class ResultsWrapper:
    """Python-visible holder of a C++ ``CSearchResults`` stored by value.

    The struct lives inside the extension object itself (no heap pointer),
    so no explicit ``__dealloc__`` is needed. The search helpers in this
    module read and write it through the ``cresults`` attribute.
    """
    cdef CSearchResults cresults

    def __cinit__(self, int num):
        # Replace the default-constructed member with one sized by ``num``
        # (meaning of ``num`` is defined by the C++ side -- TODO confirm).
        self.cresults = CSearchResults(num)

    def get_search_len(self):
        """Return ``cresults.search_lens`` coerced to a Python object."""
        search_lens = self.cresults.search_lens
        return search_lens
cdef class Roots:
    """Python-visible owner of a heap-allocated C++ ``CRoots`` batch.

    The constructor's ``root_num`` is cached on the Python side and exposed
    through the read-only ``num`` property. Every other method forwards to
    the corresponding ``CRoots`` member function.
    """
    cdef int root_num
    cdef CRoots *roots

    def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list):
        # Build the C++ object first; ``root_num`` is simply remembered.
        self.roots = new CRoots(root_num, legal_actions_list)
        self.root_num = root_num

    def __dealloc__(self):
        # Matches the ``new`` performed in ``__cinit__``.
        del self.roots

    @property
    def num(self):
        """Number of roots this batch was constructed with."""
        return self.root_num

    def prepare(self, float root_noise_weight, list noises, list value_prefix_pool, list policy_logits_pool,
                vector[int] & to_play_batch):
        """Forward to ``CRoots.prepare`` (variant that takes a noise weight and noises)."""
        self.roots.prepare(root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch)

    def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool, vector[int] & to_play_batch):
        """Forward to ``CRoots.prepare_no_noise`` (no noise arguments)."""
        self.roots.prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play_batch)

    def get_trajectories(self):
        """Return ``CRoots.get_trajectories()`` coerced to a Python object."""
        return self.roots.get_trajectories()

    def get_distributions(self):
        """Return ``CRoots.get_distributions()`` coerced to a Python object."""
        return self.roots.get_distributions()

    def get_values(self):
        """Return ``CRoots.get_values()`` coerced to a Python object."""
        return self.roots.get_values()

    def clear(self):
        """Forward to ``CRoots.clear()``."""
        self.roots.clear()
cdef class Node:
    """Python-visible wrapper around a C++ ``CNode`` stored by value.

    ``cnode`` is a by-value member, so Cython default-constructs it when the
    extension object is allocated.

    NOTE(review): the constructor arguments ``prior`` and ``legal_actions``
    are accepted but never forwarded to ``cnode``. Confirm whether they were
    meant to initialise it.
    """
    cdef CNode cnode

    def __cinit__(self, prior=None, legal_actions=None):
        # Previously ``__cinit__`` was defined twice (once with no arguments,
        # once with ``prior`` / ``legal_actions``). Only one definition of a
        # special method can be active, so the other was dead code and one of
        # the two call shapes failed. A single definition with optional
        # arguments accepts both ``Node()`` and ``Node(prior, legal_actions)``.
        # As before, the values are ignored.
        pass

    def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix,
               list policy_logits):
        """Convert ``policy_logits`` to ``vector[float]`` and call ``CNode.expand``.

        All other arguments are forwarded unchanged to the C++ method.
        """
        cdef vector[float] cpolicy = policy_logits
        self.cnode.expand(to_play, current_latent_state_index, batch_index, value_prefix, cpolicy)
def batch_backpropagate(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies,
                        MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list to_play_batch):
    """Convert the Python batch inputs to C++ vectors and call ``cbatch_backpropagate``.

    Args:
        current_latent_state_index: forwarded unchanged to the C++ routine.
        discount_factor: forwarded unchanged to the C++ routine.
        value_prefixs: list of floats, converted to ``vector[float]``.
        values: list of floats, converted to ``vector[float]``.
        policies: list of float lists, converted to ``vector[vector[float]]``.
        min_max_stats_lst: its wrapped ``CMinMaxStatsList`` pointer is passed.
        results: its wrapped ``CSearchResults`` is passed.
        to_play_batch: list of ints, converted to ``vector[int]``.

    Raises:
        TypeError: if any list cannot be coerced to the C++ vector type.
    """
    # Convert every Python list explicitly so all inputs are handled the same
    # way. (The unused ``cdef int i`` from the previous version was dropped.)
    cdef vector[float] cvalue_prefixs = value_prefixs
    cdef vector[float] cvalues = values
    cdef vector[vector[float]] cpolicies = policies
    cdef vector[int] cto_play_batch = to_play_batch

    cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
                         min_max_stats_lst.cmin_max_stats_lst, results.cresults, cto_play_batch)
def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst,
                   ResultsWrapper results, list virtual_to_play_batch):
    """Run ``cbatch_traverse`` over the root batch and return four result fields.

    The wrapped C++ objects inside ``roots``, ``min_max_stats_lst`` and
    ``results`` are passed to the C++ routine. The scalar arguments and
    ``virtual_to_play_batch`` are forwarded as given.

    Returns:
        A 4-tuple read from ``results.cresults`` after the call returns:
        ``(latent_state_index_in_search_path, latent_state_index_in_batch,
        last_actions, virtual_to_play_batchs)``. Each field is coerced to a
        Python object by Cython (presumably vectors, which become lists --
        TODO confirm against the C++ declarations).
    """
    cbatch_traverse(
        roots.roots,
        pb_c_base,
        pb_c_init,
        discount_factor,
        min_max_stats_lst.cmin_max_stats_lst,
        results.cresults,
        virtual_to_play_batch
    )
    return (
        results.cresults.latent_state_index_in_search_path,
        results.cresults.latent_state_index_in_batch,
        results.cresults.last_actions,
        results.cresults.virtual_to_play_batchs,
    )