import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.algos.dreamer_v3.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict

"""This is an example agent based on SheepRL.

Usage:
cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "./results/dreamer_v3/config.yaml" --checkpoint_path "./results/dreamer_v3/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test: bool = False):
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # There is no need to capture the video: when you submit the agent, DIAMBRA records it for you
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric
    # You must use the same precision and plugins used for training.
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
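    # e.g. Discrete(8) yields actions_dim == (8,), while MultiDiscrete([3, 7])
    # yields actions_dim == (3, 7)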
    cnn_keys = cfg.algo.cnn_keys.encoder

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You only need to retrieve the player.
    # Check what models the `build_agent()` function returns for each algorithm
    # (it lives in the algorithm's `agent.py` file) and which arguments it needs.
    # Also check the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
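    # As a sketch of that generic pattern (`model_state` is illustrative here,
    # not the DreamerV3 signature used below):
    #   agent = build_agent(fabric, ..., model_state=state["model"])[-1]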
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        world_model_state=state["world_model"],
        actor_state=state["actor"],
        critic_state=state["critic"],
        target_critic_state=state["target_critic"],
    )[-1]
    agent.eval()

    # Print policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()
    # Every time you reset the environment, you must reset the initial states of the model
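    # (the DreamerV3 player is recurrent and carries latent states across steps,
    # which would otherwise leak between episodes)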
    agent.init_states()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

        # Select actions: the agent returns a one-hot categorical sample, or
        # multiple one-hot categorical samples for multi-discrete action spaces
        actions = agent.get_actions(torch_obs, greedy=False)
        # Convert actions from one-hot encodings to categorical (integer) indices
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
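        # e.g. each one-hot tensor collapses to its integer index via argmax,
        # leaving one integer per action dimension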

        obs, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            obs, info = env.reset()
            # Every time you reset the environment, you must reset the initial states of the model
            agent.init_states()
            if info["env_done"] or test is True:
                break

    # Close the environment
    env.close()

    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Model checkpoint"
    )
    parser.add_argument("--test", action="store_true", help="Test mode")
    opt = parser.parse_args()
    print(opt)

    main(opt.cfg_path, opt.checkpoint_path, opt.test)