# diambra-agent-example/agent-dreamer_v3.py
import argparse
import json

import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.dreamer_v3.agent import build_agent
from sheeprl.algos.dreamer_v3.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict
"""This is an example agent based on SheepRL.
Usage:
cd sheeprl
diambra run python agent-dreamer_v3.py --cfg_path "./results/dreamer_v3/config.yaml" --checkpoint_path "./results/dreamer_v3/ckpt_1024_0.ckpt"
"""


def main(cfg_path: str, checkpoint_path: str, test=False):
    # Read the config file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
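    # `dotdict` allows attribute-style access to nested keys,
    # e.g. `cfg.env.num_envs` instead of `cfg["env"]["num_envs"]`.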
print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))
# Override configs for evaluation
# You do not need to capture the video since you are submitting the agent and the video is recorded by DIAMBRA
cfg.env.capture_video = False
# Only one environment is used for evaluation
cfg.env.num_envs = 1
# Instantiate Fabric
# You must use the same precision and plugins used for training.
precision = getattr(cfg.fabric, "precision", None)
plugins = getattr(cfg.fabric, "plugins", None)
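    # (The `None` fallbacks keep the script working with configs that do not
    # define these keys.)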
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create the environment
    env = make_env(cfg, 0, 0)()
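    # `make_env` returns a callable that builds the environment,
    # hence the trailing `()` that actually instantiates it.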
    observation_space = env.observation_space
    is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
    actions_dim = tuple(
        env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
    )
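    # For example, a multi-discrete DIAMBRA action space with separate move and
    # attack actions would give something like `(9, 7)`; the exact values
    # depend on the game and on the environment settings.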
    cnn_keys = cfg.algo.cnn_keys.encoder

    # Load the trained agent
    state = fabric.load(checkpoint_path)
    # You need to retrieve only the player.
    # For each algorithm, check which models the `build_agent()` function returns
    # (it is defined in the `agent.py` file of the algorithm) and which arguments it needs.
    # Also check the keys of the checkpoint: if the `build_agent()` parameter
    # is called `model_state`, then you retrieve the model state with `state["model"]`.
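    # A quick way to see which states the checkpoint provides (the checkpoint
    # loaded by Fabric is a plain dict):
    # print(state.keys())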
    agent = build_agent(
        fabric=fabric,
        actions_dim=actions_dim,
        is_continuous=False,
        cfg=cfg,
        obs_space=observation_space,
        world_model_state=state["world_model"],
        actor_state=state["actor"],
        critic_state=state["critic"],
        target_critic_state=state["target_critic"],
    )[-1]
    agent.eval()

    # Print the policy network architecture
    print("Policy architecture:")
    print(agent)

    obs, info = env.reset()
    # Every time you reset the environment, you must reset the initial states of the model
    agent.init_states()

    while True:
        # Convert numpy observations into torch observations and normalize image observations
        torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

        # Select the actions: the agent returns a one-hot categorical distribution, or
        # multiple one-hot categorical distributions for multi-discrete action spaces
        actions = agent.get_actions(torch_obs, greedy=False)
        # Convert the actions from one-hot categorical to categorical
        actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
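        # For example, two one-hot tensors of shapes (1, 9) and (1, 7) become a
        # single index tensor such as [[3, 5]] (shapes are illustrative).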
        obs, _, terminated, truncated, info = env.step(
            actions.cpu().numpy().reshape(env.action_space.shape)
        )

        if terminated or truncated:
            obs, info = env.reset()
            # Every time you reset the environment, you must reset the initial states of the model
            agent.init_states()
            if info["env_done"] or test:
                break

    # Close the environment
    env.close()
    # Return success
    return 0


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--cfg_path", type=str, required=True, help="Path to the configuration file"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="model", help="Path to the model checkpoint"
    )
    parser.add_argument(
        "--test", action="store_true", help="Test mode: exit after the first episode"
    )
    opt = parser.parse_args()
    print(opt)
    main(opt.cfg_path, opt.checkpoint_path, opt.test)
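
# A quick local sanity check (assuming a trained checkpoint exists under
# ./results/dreamer_v3/, as in the usage string above):
#
#   diambra run python agent-dreamer_v3.py \
#       --cfg_path "./results/dreamer_v3/config.yaml" \
#       --checkpoint_path "./results/dreamer_v3/ckpt_1024_0.ckpt" \
#       --test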