diambra-agent-example / agent-ppo.py
Michele Milesi
feat: update
df7e1ca
raw
history blame contribute delete
No virus
3.23 kB
import argparse
import json
import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.ppo.agent import build_agent
from sheeprl.algos.ppo.utils import prepare_obs
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict
"""This is an example agent based on SheepRL.
Usage:
diambra run python sheeprl/agent.py --cfg_path "./results/ppo/config.yaml" --checkpoint_path "./results/ppo/ckpt_1024_0.ckpt"
"""
def main(cfg_path: str, checkpoint_path: str, test: bool = False) -> int:
    """Run a trained SheepRL PPO agent on the environment described by a config.

    The configuration at ``cfg_path`` is loaded, a single evaluation
    environment is created, the agent is restored from ``checkpoint_path``
    and then acts greedily until a reset reports ``info["env_done"]`` (or,
    in test mode, until the first episode finishes).

    Args:
        cfg_path: Path to the configuration file, loaded with OmegaConf.
        checkpoint_path: Path to the checkpoint loaded with ``fabric.load``;
            its ``"agent"`` entry is handed to ``build_agent``.
        test: When true, stop as soon as the first episode ends.

    Returns:
        Always ``0``.
    """
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation.
    # No need to capture the video: when the agent is submitted the video is
    # recorded by DIAMBRA.
    cfg.env.capture_video = False
    # Only one environment is used for evaluation
    cfg.env.num_envs = 1

    # Instantiate Fabric, reusing precision/plugins from the config if present
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment (make_env returns a factory, hence the trailing call)
    env = make_env(cfg, 0, 0)()
    # Everything after the environment exists runs under try/finally so the
    # environment is closed even if loading the agent or stepping fails.
    try:
        observation_space = env.observation_space
        is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
        actions_dim = tuple(
            env.action_space.nvec.tolist()
            if is_multidiscrete
            else [env.action_space.n]
        )
        cnn_keys = cfg.algo.cnn_keys.encoder

        # Load the trained agent
        state = fabric.load(checkpoint_path)
        # build_agent returns several objects; only the last one (the player)
        # is needed for evaluation.
        agent = build_agent(
            fabric=fabric,
            actions_dim=actions_dim,
            is_continuous=False,
            cfg=cfg,
            obs_space=observation_space,
            agent_state=state["agent"],
        )[-1]
        agent.eval()

        # Print policy network architecture
        print("Policy architecture:")
        print(agent)

        obs, info = env.reset()
        # Pure inference: gradients are never used, so do not track them.
        with torch.no_grad():
            while True:
                # Convert numpy observations into torch observations and
                # normalize image observations
                torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)

                actions = agent.get_actions(torch_obs, greedy=True)
                # One output per action head: take the argmax of each head and
                # join them into a single action vector.
                actions = torch.cat(
                    [act.argmax(dim=-1) for act in actions], dim=-1
                )

                obs, _, terminated, truncated, info = env.step(
                    actions.cpu().numpy().reshape(env.action_space.shape)
                )

                if terminated or truncated:
                    obs, info = env.reset()
                    # NOTE(review): assumes the info dict returned by reset()
                    # carries "env_done" (as DIAMBRA environments do) - confirm.
                    if info["env_done"] or test:
                        break
    finally:
        # Close the environment
        env.close()

    # Return success
    return 0
if __name__ == "__main__":
    # Command-line entry point: collect the paths and the test flag, echo the
    # parsed namespace, then hand everything over to main().
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--cfg_path",
        type=str,
        required=True,
        help="Configuration file",
    )
    arg_parser.add_argument(
        "--checkpoint_path",
        type=str,
        default="model",
        help="Model checkpoint",
    )
    arg_parser.add_argument(
        "--test",
        action="store_true",
        help="Test mode",
    )
    cli_args = arg_parser.parse_args()
    print(cli_args)

    main(
        cfg_path=cli_args.cfg_path,
        checkpoint_path=cli_args.checkpoint_path,
        test=cli_args.test,
    )