File size: 2,181 Bytes
3cc5c1d
1e1c086
 
 
 
 
 
 
 
 
 
 
3cc5c1d
 
 
1e1c086
 
 
3cc5c1d
 
 
 
 
 
 
1e1c086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3cc5c1d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1e1c086
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from typing import Any, Dict, List, Optional

import numpy as np

from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnvObs,
    VecEnvStepReturn,
    VecotarableWrapper,
)


class MicrortsStatsRecorder(VecotarableWrapper):
    """Vectorized-env wrapper that adds MicroRTS episode stats to ``infos``.

    On every step it accumulates each sub-env's ``info["raw_rewards"]``. When a
    sub-env reports ``done``, the accumulated raw rewards are summed over the
    episode and written into that env's info dict as:

    * ``info["microrts_stats"]``: mapping of reward-function name (``str`` of
      each entry in ``self.env.unwrapped.rfs``) to its episode total.
    * ``info["microrts_results"]``: ``"win"``/``"draw"``/``"loss"`` 0/1 flags
      derived from the ``WinLossRewardFunction`` total. If the env index is
      assigned to a named bot, ``"<key>_<bot>"`` copies are added as well.
    """

    def __init__(
        self, env, gamma: float, bots: Optional[Dict[str, int]] = None
    ) -> None:
        """Wrap ``env`` and set up per-env reward buffers and bot labels.

        Args:
            env: The vectorized environment to wrap.
            gamma: Stored as ``self.gamma``; not otherwise used by this class.
            bots: Optional mapping of bot name -> number of envs playing that
                bot. Bot labels are assigned to the *last* env indices, in the
                mapping's iteration order; all earlier indices get ``None``.
                NOTE(review): presumably this matches how the wrapped env lays
                out its bot opponents — confirm against the env constructor.

        Raises:
            ValueError: If ``bots`` assigns more envs than ``env.num_envs``.
        """
        super().__init__(env)
        self.gamma = gamma
        # One list of per-step raw reward vectors per sub-env, for the
        # episode currently in progress.
        self.raw_rewards: List[List[Any]] = [[] for _ in range(self.num_envs)]
        self.bots = bots

        num_bot_envs = sum(bots.values()) if bots else 0
        if num_bot_envs > env.num_envs:
            # Previously this silently produced a bot_at_index list longer
            # than num_envs, attributing results to the wrong bots.
            raise ValueError(
                f"bots assigns {num_bot_envs} envs but the wrapped env only "
                f"has {env.num_envs}"
            )
        # Index -> bot name (or None when the env is not labeled with a bot).
        self.bot_at_index: List[Optional[str]] = [None] * (
            env.num_envs - num_bot_envs
        )
        if bots:
            for bot_name, count in bots.items():
                self.bot_at_index.extend([bot_name] * count)

    def reset(self) -> VecEnvObs:
        """Reset the wrapped env and clear all accumulated raw rewards."""
        obs = super().reset()
        self.raw_rewards = [[] for _ in range(self.num_envs)]
        return obs

    def step(self, actions: np.ndarray) -> VecEnvStepReturn:
        """Step the wrapped env, annotating ``infos`` in place for done envs."""
        obs, rews, dones, infos = self.env.step(actions)
        self._update_infos(infos, dones)
        return obs, rews, dones, infos

    def _update_infos(self, infos: List[Dict[str, Any]], dones: np.ndarray) -> None:
        """Record this step's raw rewards and summarize any finished episodes.

        Mutates the dicts in ``infos`` for every index where ``dones`` is true
        and clears that env's reward buffer.
        """
        # Record the current step first, so a done env's final step is
        # included in its episode totals.
        for idx, info in enumerate(infos):
            self.raw_rewards[idx].append(info["raw_rewards"])

        # Reward-function names do not vary between envs within a call;
        # compute them lazily, at most once.
        raw_names: Optional[List[str]] = None
        for idx, (info, done) in enumerate(zip(infos, dones)):
            if not done:
                continue
            if raw_names is None:
                raw_names = [str(rf) for rf in self.env.unwrapped.rfs]

            # Sum the per-step reward vectors over the episode (axis 0 = steps).
            raw_rewards = np.array(self.raw_rewards[idx]).sum(0)
            info["microrts_stats"] = dict(zip(raw_names, raw_rewards))

            # NOTE(review): an episode whose WinLoss total is 0 (e.g. possibly
            # one ended by a time limit) is counted as a draw — confirm intent.
            winloss = raw_rewards[raw_names.index("WinLossRewardFunction")]
            microrts_results = {
                "win": int(winloss == 1),
                "draw": int(winloss == 0),
                "loss": int(winloss == -1),
            }
            bot = self.bot_at_index[idx]
            if bot is not None:
                # Build the suffixed copies first so the dict is not mutated
                # while being iterated.
                microrts_results.update(
                    {f"{k}_{bot}": v for k, v in microrts_results.items()}
                )

            info["microrts_results"] = microrts_results

            # Start a fresh buffer for this env's next episode.
            self.raw_rewards[idx] = []