File size: 4,280 Bytes
6b39341 4d140f3 6b39341 eeec377 6b39341 df7e1ca 6b39341 e184661 24c3670 6b39341 e184661 6b39341 4d140f3 6b39341 e184661 6b39341 e184661 24c3670 41306ef 6b39341 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 |
import argparse
import json
import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.algos.dreamer_v3.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict
"""This is an example agent based on SheepRL.
Usage:
cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "./results/dreamer_v3/config.yaml" --checkpoint_path "./results/dreamer_v3/ckpt_1024_0.ckpt"
"""
def main(cfg_path: str, checkpoint_path: str, test=False):
    """Run a trained SheepRL DreamerV3 player inside the evaluation environment.

    Args:
        cfg_path: Path of the configuration file that is loaded with OmegaConf.
        checkpoint_path: Path of the checkpoint that is loaded with ``fabric.load``;
            it must provide the ``world_model``, ``actor``, ``critic`` and
            ``target_critic`` states.
        test: When True, leave the loop after the first environment reset
            instead of waiting for ``info["env_done"]``.

    Returns:
        0 once the loop has ended and the environment has been closed.
    """
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # You do not need to capture the video since you are submitting the agent
    # and the video is recorded by DIAMBRA
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric
    # You must use the same precision and plugins used for training.
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    # From here on the environment is open: make sure it is closed even when
    # loading the checkpoint, building the agent or stepping fails.
    try:
        observation_space = env.observation_space
        is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
        actions_dim = tuple(
            env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
        )
        cnn_keys = cfg.algo.cnn_keys.encoder

        # Load the trained agent
        state = fabric.load(checkpoint_path)
        # You need to retrieve only the player
        # Check for each algorithm what models the `build_agent()` function returns
        # (placed in the `agent.py` file of the algorithm), and which arguments it needs.
        # Check also which are the keys of the checkpoint: if the `build_agent()` parameter
        # is called `model_state`, then you retrieve the model state with `state["model"]`.
        agent = build_agent(
            fabric=fabric,
            actions_dim=actions_dim,
            is_continuous=False,
            cfg=cfg,
            obs_space=observation_space,
            world_model_state=state["world_model"],
            actor_state=state["actor"],
            critic_state=state["critic"],
            target_critic_state=state["target_critic"],
        )[-1]
        agent.eval()

        # Print policy network architecture
        print("Policy architecture:")
        print(agent)

        # Pure inference: `eval()` does not disable gradient tracking, so turn
        # autograd off explicitly to avoid building a graph at every step.
        with torch.no_grad():
            obs, info = env.reset()
            # Every time you reset the environment, you must reset the initial states of the model
            agent.init_states()

            while True:
                # Convert numpy observations into torch observations and normalize image observations
                torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
                # Select actions, the agent returns a one-hot categorical or
                # more one-hot categorical distributions for multi-discrete action spaces
                actions = agent.get_actions(torch_obs, greedy=False)
                # Convert actions from one-hot categorical to categorical
                actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

                obs, _, terminated, truncated, info = env.step(
                    actions.cpu().numpy().reshape(env.action_space.shape)
                )

                if terminated or truncated:
                    obs, info = env.reset()
                    # Every time you reset the environment, you must reset the initial states of the model
                    agent.init_states()
                    # NOTE(review): the exit check reads the `info` returned by the
                    # reset above, so it is kept inside this branch - confirm that
                    # this matches the intended nesting of the original script.
                    if info["env_done"] or test is True:
                        break
    finally:
        # Close the environment
        env.close()

    # Return success
    return 0
if __name__ == "__main__":
    # Command-line entry point: collect the paths and the test flag, echo the
    # parsed namespace, then hand everything over to `main`.
    cli = argparse.ArgumentParser()
    cli.add_argument("--cfg_path", type=str, required=True, help="Configuration file")
    cli.add_argument("--checkpoint_path", type=str, default="model", help="Model checkpoint")
    cli.add_argument("--test", action="store_true", help="Test mode")

    args = cli.parse_args()
    print(args)

    main(args.cfg_path, args.checkpoint_path, args.test)
|