"""Hand-coded heuristic policies (sticky, linear, mirroring, random) for RLlib."""
import numpy as np
import random
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override
from ray.rllib.models.modelv2 import restore_original_dimensions
class HeuristicBase(Policy):
    """Common scaffolding for hand-coded (non-learning) RLlib policies.

    Subclasses implement ``_do_compute_actions(obs_batch)`` and receive the
    observations already restored to their original (dict) structure.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    def learn_on_batch(self, samples):
        """Heuristics do not learn; the sample batch is ignored."""
        pass

    def get_weights(self):
        """No weights to save."""
        return {}

    def set_weights(self, weights):
        """No weights to set."""
        pass

    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        # Unflatten the batched observations back into their original dict
        # layout before delegating to the subclass-specific action logic.
        restored = restore_original_dimensions(
            np.array(obs_batch, dtype=np.float32),
            self.observation_space,
            tensorlib=np,
        )
        return self._do_compute_actions(restored)

    def pick_legal_action(self, legal_action):
        """Sample uniformly among indices whose mask entry equals 1."""
        candidates = np.flatnonzero(legal_action == 1)
        return np.random.choice(candidates)
class AlwaysSameHeuristic(HeuristicBase):
    """Pick a random column and stick with it for as long as it stays legal.

    Fix: ``_rand_choice`` used to be drawn once at class-definition (import)
    time, so every instance in the process started with the same column.  It
    is now drawn per instance in ``__init__``; the class attribute is kept
    only as a backward-compatibility fallback.
    """

    # Backward-compat fallback for code that reads the attribute before
    # instantiation; instances get their own value in __init__.
    _rand_choice = random.choice(range(7))

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 7 == number of columns -- presumably a Connect Four board;
        # TODO(review): derive from self.action_space instead of hard-coding.
        self._rand_choice = random.choice(range(7))

    def _do_compute_actions(self, obs_batch):
        """Return (actions, state_outs, info) for a batch of masked obs."""

        def select_action(legal_action):
            legal_choices = np.arange(len(legal_action))[legal_action == 1]
            # Re-draw only when the sticky choice has become illegal.
            if self._rand_choice not in legal_choices:
                self._rand_choice = np.random.choice(legal_choices)
            return self._rand_choice

        return [select_action(x) for x in obs_batch["action_mask"]], [], {}
class LinearHeuristic(HeuristicBase):
    """Pick a random starting column, then step the column index in a fixed
    random direction (+1 or -1) each turn, re-drawing whenever the resulting
    move would be illegal.

    Fix: both random attributes used to be class-level, drawn once at import
    time and therefore shared by every instance; they are now initialized per
    instance in ``__init__`` (class attributes kept as a fallback).
    """

    # Backward-compat fallbacks; instances get their own values in __init__.
    _rand_choice = random.choice(range(7))
    _rand_sign = np.random.choice([-1, 1])

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # 7 == number of columns -- TODO(review): derive from action space.
        self._rand_choice = random.choice(range(7))
        self._rand_sign = np.random.choice([-1, 1])

    def _do_compute_actions(self, obs_batch):
        """Return (actions, state_outs, info) for a batch of masked obs."""

        def select_action(legal_action):
            legal_choices = np.arange(len(legal_action))[legal_action == 1]
            # Advance one column in the fixed direction; re-draw if illegal.
            self._rand_choice += 1 * self._rand_sign
            if self._rand_choice not in legal_choices:
                self._rand_choice = np.random.choice(legal_choices)
            return self._rand_choice

        return [select_action(x) for x in obs_batch["action_mask"]], [], {}
class BeatLastHeuristic(HeuristicBase):
    """Play in a column suggested by the opponent's last move.

    Candidate columns are those where the second observation plane holds more
    pieces than the first; when none qualify, or the chosen column is illegal,
    a uniformly random legal move is played instead.
    """

    def _do_compute_actions(self, obs_batch):
        def select_action(legal_action, observation):
            legal_choices = np.arange(len(legal_action))[legal_action == 1]
            # Per-column piece counts for each of the two player planes.
            # NOTE(review): assumes observation is (rows, cols, 2) with
            # plane 1 holding the opponent's pieces -- confirm with the env.
            column_totals = observation.sum(axis=0)
            behind = np.squeeze(
                np.argwhere(column_totals[:, 0] < column_totals[:, 1])
            )
            if behind.size == 0:
                # No candidate column: fall back to any legal move.
                return np.random.choice(legal_choices)
            if behind.size == 1:
                # squeeze() produced a 0-d array; extract the scalar.
                wanted = behind[()]
            else:
                wanted = np.random.choice(behind)
            if wanted in legal_choices:
                return wanted
            return np.random.choice(legal_choices)

        actions = [
            select_action(mask, obs)
            for mask, obs in zip(obs_batch["action_mask"], obs_batch["observation"])
        ]
        return actions, [], {}
class RandomHeuristic(HeuristicBase):
    """Uniformly sample a legal action for every observation in the batch.

    The environment's observation must be a dictionary whose "action_mask"
    entry marks the legal actions for the agent with 1s.
    """

    def _do_compute_actions(self, obs_batch):
        actions = [
            self.pick_legal_action(mask) for mask in obs_batch["action_mask"]
        ]
        return actions, [], {}