from typing import Dict, List, Optional

from mlagents_envs.base_env import BaseEnv, BehaviorName, BehaviorSpec
from mlagents.trainers.env_manager import EnvManager, EnvironmentStep, AllStepResult
from mlagents_envs.timers import timed
from mlagents.trainers.action_info import ActionInfo
from mlagents.trainers.settings import ParameterRandomizationSettings
from mlagents_envs.side_channel.environment_parameters_channel import (
    EnvironmentParametersChannel,
)


class SimpleEnvManager(EnvManager):
    """
    Simple implementation of the EnvManager interface that only handles one BaseEnv at a time.
    This is generally only useful for testing; see SubprocessEnvManager for a production-quality implementation.
    """

    def __init__(self, env: BaseEnv, env_params: EnvironmentParametersChannel):
        super().__init__()
        self.env_params = env_params
        self.env = env
        self.previous_step: EnvironmentStep = EnvironmentStep.empty(0)
        self.previous_all_action_info: Dict[BehaviorName, ActionInfo] = {}

    def _step(self) -> List[EnvironmentStep]:
        # Ask each policy for actions based on the previous step's results.
        all_action_info = self._take_step(self.previous_step)
        self.previous_all_action_info = all_action_info

        # Forward the actions to the environment and advance the simulation.
        for brain_name, action_info in all_action_info.items():
            self.env.set_actions(brain_name, action_info.env_action)
        self.env.step()
        all_step_result = self._generate_all_results()

        step_info = EnvironmentStep(
            all_step_result, 0, self.previous_all_action_info, {}
        )
        self.previous_step = step_info
        return [step_info]

    def _reset_env(
        self, config: Optional[Dict[BehaviorName, float]] = None
    ) -> List[EnvironmentStep]:  # type: ignore
        self.set_env_parameters(config)
        self.env.reset()
        all_step_result = self._generate_all_results()
        self.previous_step = EnvironmentStep(all_step_result, 0, {}, {})
        return [self.previous_step]

    def set_env_parameters(self, config: Optional[Dict] = None) -> None:
        """
        Sends environment parameter settings to the C# side via the
        EnvironmentParametersChannel side channel.
        :param config: Dict of environment parameter keys and values
        """
        if config is not None:
            for k, v in config.items():
                if isinstance(v, float):
                    self.env_params.set_float_parameter(k, v)
                elif isinstance(v, ParameterRandomizationSettings):
                    v.apply(k, self.env_params)
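
    # A hypothetical call, assuming an environment that exposes a float
    # parameter named "gravity" (the key is illustrative, not defined here):
    #
    #     manager.set_env_parameters({"gravity": -9.81})
    #
    # Float values are forwarded through set_float_parameter; any
    # ParameterRandomizationSettings value applies itself to the channel.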

    @property
    def training_behaviors(self) -> Dict[BehaviorName, BehaviorSpec]:
        return self.env.behavior_specs

    def close(self):
        self.env.close()

    @timed
    def _take_step(self, last_step: EnvironmentStep) -> Dict[BehaviorName, ActionInfo]:
        all_action_info: Dict[BehaviorName, ActionInfo] = {}
        for brain_name, step_tuple in last_step.current_all_step_result.items():
            all_action_info[brain_name] = self.policies[brain_name].get_action(
                step_tuple[0],  # DecisionSteps: agents currently requesting a decision
                0,  # Since there is only one worker, the worker_id is always 0.
            )
        return all_action_info

    def _generate_all_results(self) -> AllStepResult:
        # Collect the latest (DecisionSteps, TerminalSteps) for every behavior.
        all_step_result: AllStepResult = {}
        for brain_name in self.env.behavior_specs:
            all_step_result[brain_name] = self.env.get_steps(brain_name)
        return all_step_result
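

if __name__ == "__main__":
    # Hedged usage sketch, not part of the original module: wiring a
    # SimpleEnvManager to a local Unity build for a quick smoke test.
    # "MyEnvBuild" is a hypothetical build path, and the call to
    # manager.reset() assumes the public reset() defined on the EnvManager
    # base class, which delegates to _reset_env above.
    from mlagents_envs.environment import UnityEnvironment

    channel = EnvironmentParametersChannel()
    unity_env = UnityEnvironment(file_name="MyEnvBuild", side_channels=[channel])
    manager = SimpleEnvManager(unity_env, channel)
    try:
        manager.reset()
        print(manager.training_behaviors)  # BehaviorName -> BehaviorSpec mapping
    finally:
        manager.close()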