from typing import Optional

from ray.rllib.env.multi_agent_env import MultiAgentEnv
from ray.rllib.utils.annotations import PublicAPI
from ray.rllib.utils.gym import convert_old_gym_space_to_gymnasium_space


@PublicAPI
class Connect4Env(MultiAgentEnv):
    """An interface to the PettingZoo MARL environment library.

    See: https://github.com/Farama-Foundation/PettingZoo

    Inherits from MultiAgentEnv and exposes a given AEC
    (actor-environment-cycle) game from the PettingZoo project via the
    MultiAgentEnv public API.

    Note that the wrapper has some important limitations:

    1. All agents must share the same action space and observation space.
       Note: If the agents in your AEC game do not have homogeneous action /
       observation spaces, use SuperSuit's padding wrappers (see the sketch
       below this docstring): https://github.com/Farama-Foundation/
       SuperSuit#built-in-multi-agent-only-functions
    2. Environments are positive-sum games (-> agents are expected to cooperate
       to maximize reward). This isn't a hard restriction; it's just that
       standard algorithms aren't expected to work well in highly competitive
       games."""

    def __init__(self, env):
        super().__init__()
        self.env = env
        env.reset()

        # Since all agents have the same spaces, we do not provide the full
        # observation and action spaces as Dicts mapping agent IDs to the
        # individual agents' spaces. Instead, `self.observation_space` and
        # `self.action_space` are the single-agent spaces.
        self._obs_space_in_preferred_format = False
        self._action_space_in_preferred_format = False

        # Collect the individual agents' spaces (they should all be the same):
        first_obs_space = self.env.observation_space(self.env.agents[0])
        first_action_space = self.env.action_space(self.env.agents[0])

        for agent in self.env.agents:
            if self.env.observation_space(agent) != first_obs_space:
                raise ValueError(
                    "Observation spaces for all agents must be identical. Perhaps "
                    "SuperSuit's pad_observations wrapper can help (usage: "
                    "`supersuit.aec_wrappers.pad_observations(env)`)."
                )
            if self.env.action_space(agent) != first_action_space:
                raise ValueError(
                    "Action spaces for all agents must be identical. Perhaps "
                    "SuperSuit's pad_action_space wrapper can help (usage: "
                    "`supersuit.aec_wrappers.pad_action_space(env)`)."
                )

        # Convert from gym to gymnasium, if necessary.
        self.observation_space = convert_old_gym_space_to_gymnasium_space(
            first_obs_space
        )
        self.action_space = convert_old_gym_space_to_gymnasium_space(first_action_space)

        self._agent_ids = set(self.env.agents)

    def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
        info = self.env.reset(seed=seed, options=options)
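        # In an AEC game only the currently selected agent acts next, so the
        # reset observation dict contains a single entry for `agent_selection`.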
        return (
            {self.env.agent_selection: self.env.observe(self.env.agent_selection)},
            info or {},
        )

    def step(self, action):
        self.env.step(action[self.env.agent_selection])
        obs_d = {}
        rew_d = {}
        terminated_d = {}
        truncated_d = {}
        info_d = {}
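        # Step through agents whose episodes are already over, collecting
        # their final `last()` results, until we reach an agent that still
        # has to act (exactly one live agent acts per `step()` call).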
        while self.env.agents:
            obs, rew, terminated, truncated, info = self.env.last()
            agent_id = self.env.agent_selection
            obs_d[agent_id] = obs
            rew_d[agent_id] = rew
            terminated_d[agent_id] = terminated
            truncated_d[agent_id] = truncated
            info_d[agent_id] = info
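            # A finished agent must be stepped with `None` so that PettingZoo
            # removes it from `env.agents`; a live agent ends the drain and
            # its observation is returned for the next action.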
            if (
                self.env.terminations[self.env.agent_selection]
                or self.env.truncations[self.env.agent_selection]
            ):
                self.env.step(None)
            else:
                break

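        # RLlib's MultiAgentEnv protocol requires the special `__all__` key:
        # the episode is only over once every agent has been removed from
        # `env.agents`.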
        all_gone = not self.env.agents
        terminated_d["__all__"] = all_gone and all(terminated_d.values())
        truncated_d["__all__"] = all_gone and all(truncated_d.values())

        return obs_d, rew_d, terminated_d, truncated_d, info_d

    def close(self):
        self.env.close()

    def render(self):
        return self.env.render()

    @property
    def get_sub_environments(self):
        return self.env.unwrapped
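

# A minimal usage sketch, assuming PettingZoo with the classic extras is
# installed. `connect_four_v3` is PettingZoo's Connect Four AEC game; the
# random rollout below is illustrative only and samples from the game's
# action mask so that only legal moves are played.
if __name__ == "__main__":
    import numpy as np

    from pettingzoo.classic import connect_four_v3

    env = Connect4Env(connect_four_v3.env())
    obs, infos = env.reset(seed=42)

    terminateds = {"__all__": False}
    truncateds = {"__all__": False}
    while not terminateds["__all__"] and not truncateds["__all__"]:
        # Exactly one agent acts per step; its observation is the only entry
        # in `obs` while the game is still running.
        agent_id, agent_obs = next(iter(obs.items()))
        legal_moves = np.flatnonzero(agent_obs["action_mask"])
        obs, rewards, terminateds, truncateds, infos = env.step(
            {agent_id: int(np.random.choice(legal_moves))}
        )

    print("Final rewards:", rewards)
    env.close()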