File size: 1,289 Bytes
0e936e1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from typing import Optional, Union

import numpy as np

from rl_algo_impls.wrappers.vectorable_wrapper import (
    VecEnv,
    VecotarableWrapper,
    find_wrapper,
)


class IncompleteArrayError(Exception):
    """Raised when an array is found to be incomplete.

    Carries no extra state beyond the standard ``Exception`` arguments.
    """


class SingleActionMaskWrapper(VecotarableWrapper):
    """Collect per-sub-environment action masks from a synchronous vector env.

    The wrapped env's ``unwrapped`` object must expose a non-empty ``envs``
    collection. Each sub-environment's ``unwrapped`` object must expose an
    ``action_mask`` attribute.
    """

    def action_masks(self) -> Optional[np.ndarray]:
        """Return the action masks of every sub-environment as one array.

        Returns:
            A boolean array built from
            ``[e.unwrapped.action_mask for e in envs]``. Its leading axis
            therefore indexes sub-environments (this assumes every mask has
            the same shape — TODO confirm). The method never returns ``None``.
            The ``Optional`` annotation is kept only for interface
            compatibility.

        Raises:
            AssertionError: if the wrapped env has no ``envs`` attribute, or it
                is empty or falsy, or if any sub-environment's
                ``action_mask`` is ``None``.
            AttributeError: if a sub-environment has no ``action_mask``
                attribute at all.
        """
        # Default to None so a missing ``envs`` attribute reaches the
        # descriptive assertion below instead of a bare AttributeError.
        envs = getattr(self.env.unwrapped, "envs", None)
        assert (
            envs
        ), f"{self.__class__.__name__} expects to wrap synchronous vectorized env"
        masks = [getattr(e.unwrapped, "action_mask") for e in envs]
        assert all(
            m is not None for m in masks
        ), f"{self.__class__.__name__}: a sub-environment's action_mask is None"
        # ``np.bool8`` was deprecated in NumPy 1.24 and removed in NumPy 2.0;
        # ``np.bool_`` is the portable name for the same boolean dtype.
        return np.array(masks, dtype=np.bool_)


class MicrortsMaskWrapper(VecotarableWrapper):
    """Fetch action masks from the ``vec_client`` of a MicroRTS vector env."""

    def action_masks(self) -> np.ndarray:
        """Return the masks reported by ``vec_client.getMasks(0)``.

        Returns:
            The result of ``vec_client.getMasks(0)`` converted to a boolean
            NumPy array. The MicroRTS client defines what the ``0`` argument
            means (presumably a player index — TODO confirm).

        Raises:
            AssertionError: if the unwrapped env has no ``vec_client``
                attribute, or if that attribute is falsy.
        """
        microrts_env = self.env.unwrapped  # type: ignore
        # Default to None so a missing ``vec_client`` reaches the descriptive
        # assertion below instead of a bare AttributeError.
        vec_client = getattr(microrts_env, "vec_client", None)
        assert (
            vec_client
        ), f"{microrts_env.__class__.__name__} must have vec_client property (as MicroRTSVecEnv does)"
        # ``np.bool8`` was removed in NumPy 2.0; ``np.bool_`` is equivalent.
        return np.array(vec_client.getMasks(0), dtype=np.bool_)


def find_action_masker(
    env: VecEnv,
) -> Optional[Union[SingleActionMaskWrapper, MicrortsMaskWrapper]]:
    """Locate an action-mask wrapper somewhere in ``env``'s wrapper stack.

    A ``SingleActionMaskWrapper`` takes priority. If ``find_wrapper`` returns a
    falsy result for it, the result of searching for a
    ``MicrortsMaskWrapper`` is returned unchanged, whatever that result is.
    """
    single_masker = find_wrapper(env, SingleActionMaskWrapper)
    if single_masker:
        return single_masker
    return find_wrapper(env, MicrortsMaskWrapper)