# connectfour/connectfour/training/dummy_policies.py
import random

import numpy as np
from ray.rllib.models.modelv2 import restore_original_dimensions
from ray.rllib.policy.policy import Policy
from ray.rllib.utils.annotations import override


class HeuristicBase(Policy):
    """Shared plumbing for the scripted Connect Four policies below."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.exploration = self._create_exploration()

    @override(Policy)
    def learn_on_batch(self, samples):
        """Heuristic policies do not learn from experience."""
        return {}

    @override(Policy)
    def get_weights(self):
        """No weights to save."""
        return {}

    @override(Policy)
    def set_weights(self, weights):
        """No weights to set."""
        pass

    @override(Policy)
    def compute_actions(
        self,
        obs_batch,
        state_batches=None,
        prev_action_batch=None,
        prev_reward_batch=None,
        info_batch=None,
        episodes=None,
        **kwargs
    ):
        # Un-flatten the batched observations back into the original Dict
        # space so that "action_mask" and "observation" are accessible.
        obs_batch = restore_original_dimensions(
            np.array(obs_batch, dtype=np.float32), self.observation_space, tensorlib=np
        )
        return self._do_compute_actions(obs_batch)

    def pick_legal_action(self, legal_action):
        # Indices where the action mask is 1 are the playable columns.
        legal_choices = np.arange(len(legal_action))[legal_action == 1]
        return np.random.choice(legal_choices)


class AlwaysSameHeuristic(HeuristicBase):
    """
    Pick a random column and keep playing it, switching to a new random
    legal column only once the current one fills up.
    """

    _rand_choice = random.choice(range(7))

    def _do_compute_actions(self, obs_batch):
        def select_action(legal_action):
            legal_choices = np.arange(len(legal_action))[legal_action == 1]
            # Re-roll only when the remembered column is no longer playable.
            if self._rand_choice not in legal_choices:
                self._rand_choice = np.random.choice(legal_choices)
            return self._rand_choice

        return [select_action(x) for x in obs_batch["action_mask"]], [], {}


class LinearHeuristic(HeuristicBase):
    """
    Pick a random starting column, then shift one column per move in a
    random fixed direction, re-rolling whenever the shifted column is
    not a legal move.
    """

    _rand_choice = random.choice(range(7))
    _rand_sign = np.random.choice([-1, 1])

    def _do_compute_actions(self, obs_batch):
        def select_action(legal_action):
            legal_choices = np.arange(len(legal_action))[legal_action == 1]
            # Drift one column per call in the chosen direction.
            self._rand_choice += 1 * self._rand_sign
            if self._rand_choice not in legal_choices:
                self._rand_choice = np.random.choice(legal_choices)
            return self._rand_choice

        return [select_action(x) for x in obs_batch["action_mask"]], [], {}


class BeatLastHeuristic(HeuristicBase):
    """
    Answer the opponent's last move: play in a column where the opponent
    has more pieces than this agent, falling back to a random legal move.
    """

    def _do_compute_actions(self, obs_batch):
        def select_action(legal_action, observation):
            legal_choices = np.arange(len(legal_action))[legal_action == 1]
            # Per-column piece counts for each of the two observation planes.
            obs_sums = np.sum(observation, axis=0)
            # Columns where the opponent's plane holds more pieces than ours.
            desired_actions = np.squeeze(np.argwhere(obs_sums[:, 0] < obs_sums[:, 1]))
            if desired_actions.size == 0:
                return np.random.choice(legal_choices)

            if desired_actions.size == 1:
                # np.squeeze produced a 0-d array; extract the scalar.
                desired_action = desired_actions[()]
            else:
                desired_action = np.random.choice(desired_actions)

            if desired_action in legal_choices:
                return desired_action
            return np.random.choice(legal_choices)

        return (
            [
                select_action(x, y)
                for x, y in zip(obs_batch["action_mask"], obs_batch["observation"])
            ],
            [],
            {},
        )


class RandomHeuristic(HeuristicBase):
    """
    Just pick a random legal action.

    The environment's observation must be a dict with an "action_mask"
    key listing the legal actions for the agent.
    """

    def _do_compute_actions(self, obs_batch):
        return [self.pick_legal_action(x) for x in obs_batch["action_mask"]], [], {}
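

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original file: a minimal, numpy-only
# demonstration of the two selection rules used above, plus a hedged example
# of how such a heuristic might be wired into an RLlib multi-agent setup.
# The 6x7x2 board layout, the agent ids "player_0"/"player_1" and the policy
# ids "learned"/"random" are assumptions for illustration only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # 1) Legal-action masking, as used by HeuristicBase.pick_legal_action:
    #    columns 2 and 5 are the only playable moves here.
    mask = np.array([0, 0, 1, 0, 0, 1, 0], dtype=np.float32)
    legal_columns = np.arange(len(mask))[mask == 1]
    print("legal columns:", legal_columns)            # -> [2 5]
    print("random legal move:", np.random.choice(legal_columns))

    # 2) BeatLastHeuristic's column comparison on a toy 6x7x2 board where the
    #    opponent (plane 1) has just dropped a piece into column 3.
    board = np.zeros((6, 7, 2), dtype=np.float32)
    board[5, 3, 1] = 1.0
    col_sums = np.sum(board, axis=0)                  # shape (7, 2)
    desired = np.squeeze(np.argwhere(col_sums[:, 0] < col_sums[:, 1]))
    print("column to answer:", desired)               # -> 3

    # 3) Hedged sketch of registering a heuristic next to a trained policy in
    #    an RLlib multi-agent config (the exact config API varies by Ray
    #    version; adjust to the version actually pinned by this project).
    from ray.rllib.policy.policy import PolicySpec

    policies = {
        "learned": PolicySpec(),                             # trained by RLlib
        "random": PolicySpec(policy_class=RandomHeuristic),  # frozen heuristic
    }

    def policy_mapping_fn(agent_id, *args, **kwargs):
        # Assumed agent ids; adapt to the actual environment's ids.
        return "learned" if agent_id == "player_0" else "random"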