File size: 1,743 Bytes
0e936e1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
from typing import TypeVar
import numpy as np
from gym_microrts.envs.vec_env import MicroRTSGridModeVecEnv
from jpype.types import JArray, JInt
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnvStepReturn
MicroRTSGridModeVecEnvCompatSelf = TypeVar(
"MicroRTSGridModeVecEnvCompatSelf", bound="MicroRTSGridModeVecEnvCompat"
)
class MicroRTSGridModeVecEnvCompat(MicroRTSGridModeVecEnv):
def step(self, action: np.ndarray) -> VecEnvStepReturn:
indexed_actions = np.concatenate(
[
np.expand_dims(
np.stack(
[np.arange(0, action.shape[1]) for i in range(self.num_envs)]
),
axis=2,
),
action,
],
axis=2,
)
action_mask = np.array(self.vec_client.getMasks(0), dtype=np.bool8).reshape(
indexed_actions.shape[:-1] + (-1,)
)
valid_action_mask = action_mask[:, :, 0]
valid_actions_counts = valid_action_mask.sum(1)
valid_actions = indexed_actions[valid_action_mask]
valid_actions_idx = 0
all_valid_actions = []
for env_act_cnt in valid_actions_counts:
env_valid_actions = []
for _ in range(env_act_cnt):
env_valid_actions.append(JArray(JInt)(valid_actions[valid_actions_idx]))
valid_actions_idx += 1
all_valid_actions.append(JArray(JArray(JInt))(env_valid_actions))
return super().step(JArray(JArray(JArray(JInt)))(all_valid_actions)) # type: ignore
@property
def unwrapped(
self: MicroRTSGridModeVecEnvCompatSelf,
) -> MicroRTSGridModeVecEnvCompatSelf:
return self
|