File size: 2,064 Bytes
1e1c086
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
from dataclasses import asdict
from typing import Optional

from torch.utils.tensorboard.writer import SummaryWriter

from rl_algo_impls.runner.config import Config, EnvHyperparams
from rl_algo_impls.shared.vec_env.microrts import make_microrts_env
from rl_algo_impls.shared.vec_env.procgen import make_procgen_env
from rl_algo_impls.shared.vec_env.vec_env import make_vec_env
from rl_algo_impls.wrappers.vectorable_wrapper import VecEnv


def make_env(
    config: Config,
    hparams: EnvHyperparams,
    training: bool = True,
    render: bool = False,
    normalize_load_path: Optional[str] = None,
    tb_writer: Optional[SummaryWriter] = None,
) -> VecEnv:
    """Create a vectorized environment using the factory for ``hparams.env_type``.

    Supported ``env_type`` values and the factory each one selects:

    * ``"procgen"`` -> ``make_procgen_env``
    * ``"sb3vec"`` or ``"gymvec"`` -> ``make_vec_env``
    * ``"microrts"`` -> ``make_microrts_env``

    Args:
        config: Run configuration, passed through to the factory.
        hparams: Environment hyperparameters; ``env_type`` picks the factory.
        training: Passed through to the factory as ``training``.
        render: Passed through to the factory as ``render``.
        normalize_load_path: Passed through to the factory as
            ``normalize_load_path``.
        tb_writer: Passed through to the factory as ``tb_writer``.

    Returns:
        Whatever the selected factory returns.

    Raises:
        ValueError: If ``hparams.env_type`` is not one of the supported values.
    """
    env_type = hparams.env_type

    # Resolve the factory first so the (identical) call is written only once.
    if env_type == "procgen":
        env_factory = make_procgen_env
    elif env_type in {"sb3vec", "gymvec"}:
        env_factory = make_vec_env
    elif env_type == "microrts":
        env_factory = make_microrts_env
    else:
        raise ValueError(f"env_type {hparams.env_type} not supported")

    return env_factory(
        config,
        hparams,
        training=training,
        render=render,
        normalize_load_path=normalize_load_path,
        tb_writer=tb_writer,
    )


def make_eval_env(
    config: Config,
    hparams: EnvHyperparams,
    override_n_envs: Optional[int] = None,
    **kwargs,
) -> VecEnv:
    """Create an environment for evaluation via ``make_env``.

    ``training`` is always forced to ``False``, overriding any ``training``
    value supplied in ``kwargs``.

    Args:
        config: Run configuration, passed through to ``make_env``.
        hparams: Environment hyperparameters. Used as-is unless
            ``override_n_envs`` is given.
        override_n_envs: When not ``None``, a new ``EnvHyperparams`` is built
            from ``asdict(hparams)`` with ``n_envs`` replaced by this value.
            If the value is exactly ``1``, ``vec_env_class`` is also set to
            ``"sync"``.
        **kwargs: Extra keyword arguments forwarded to ``make_env``.

    Returns:
        The environment returned by ``make_env``.
    """
    eval_kwargs = {**kwargs, "training": False}

    if override_n_envs is None:
        return make_env(config, hparams, **eval_kwargs)

    hparam_overrides = {"n_envs": override_n_envs}
    if override_n_envs == 1:
        # NOTE(review): presumably a single env has no need for a non-"sync"
        # vec env class -- confirm against the vec env factories.
        hparam_overrides["vec_env_class"] = "sync"

    eval_hparams = EnvHyperparams(**{**asdict(hparams), **hparam_overrides})
    return make_env(config, eval_hparams, **eval_kwargs)