diambra-agent-example / agent-ppo.py
michele-milesi's picture
Rename ppo-agent.py to agent-ppo.py
12794d0 verified
raw
history blame
3.58 kB
import argparse
import json
import gymnasium as gym
import torch
from lightning import Fabric
from omegaconf import OmegaConf
from sheeprl.algos.ppo.agent import build_agent
from sheeprl.utils.env import make_env
from sheeprl.utils.utils import dotdict
"""This is an example agent based on SheepRL.
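Usage:
diambra run python agent-ppo.py --cfg_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/config.yaml" --checkpoint_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt"
Add the optional --test flag to stop as soon as an episode ends instead of waiting for the environment to report "env_done".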
"""
def _preprocess_obs(o, obs_keys, cnn_keys, mlp_keys, device):
    """Turn one raw observation dict into the batched tensor dict given to the agent.

    Only keys listed in ``obs_keys`` are kept. Every kept entry is copied,
    moved to ``device`` and given a leading batch dimension of size 1.
    Entries whose key is in ``cnn_keys`` are reshaped to
    ``(1, -1, H, W)`` and rescaled with ``x / 255 - 0.5``; entries whose key
    is in ``mlp_keys`` are cast to float.

    Args:
        o: Mapping from observation key to a numpy array, as returned by the
            environment's ``reset``/``step``.
        obs_keys: Keys the agent consumes (MLP keys followed by CNN keys).
        cnn_keys: Keys treated as image observations.
        mlp_keys: Keys treated as vector observations.
        device: Torch device the tensors are moved to.

    Returns:
        A dict mapping each selected key to its prepared tensor.
    """
    obs = {}
    for k, value in o.items():
        if k not in obs_keys:
            continue
        # Copy before wrapping so the tensor never aliases the env's buffer.
        torch_obs = torch.from_numpy(value.copy()).to(device).unsqueeze(0)
        if k in cnn_keys:
            # Fold all leading dims into one channel axis and normalize pixel values.
            torch_obs = torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
        if k in mlp_keys:
            torch_obs = torch_obs.float()
        obs[k] = torch_obs
    return obs


def main(cfg_path: str, checkpoint_path: str, test: bool = False) -> int:
    """Load a trained SheepRL PPO agent and run it greedily in the configured environment.

    The configuration at ``cfg_path`` is loaded and printed, a single
    environment is created from it, the agent weights are restored from
    ``checkpoint_path`` and the agent then acts until the stop condition is met.
    The environment is always closed, even if an error occurs while building
    the agent or stepping the environment.

    Args:
        cfg_path: Path to the experiment's YAML configuration file.
        checkpoint_path: Path to the checkpoint holding the ``"agent"`` state.
        test: When true, video capture is left as configured and the loop
            stops at the end of the first episode; otherwise it stops when the
            info returned by ``env.reset()`` has a truthy ``"env_done"``.

    Returns:
        ``0`` on success.
    """
    # Read the cfg file
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Override configs for evaluation
    # NOTE(review): nesting reconstructed as "only the video override is
    # skipped in test mode"; exactly one env is created below, so num_envs is
    # forced to 1 in every mode -- confirm against the original layout.
    if not test:
        cfg.env.capture_video = False
    cfg.env.num_envs = 1

    # Instantiate Fabric
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # Create Environment
    env = make_env(cfg, 0, 0)()
    try:
        observation_space = env.observation_space
        is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
        actions_dim = tuple(
            env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
        )
        cnn_keys = cfg.algo.cnn_keys.encoder
        mlp_keys = cfg.algo.mlp_keys.encoder
        obs_keys = mlp_keys + cnn_keys

        # Load the trained agent
        state = fabric.load(checkpoint_path)
        # Only the last element returned by build_agent (the player) is needed
        agent = build_agent(
            fabric=fabric,
            actions_dim=actions_dim,
            is_continuous=False,
            cfg=cfg,
            obs_space=observation_space,
            agent_state=state["agent"],
        )[-1]
        agent.eval()

        # Print policy network architecture
        print("Policy architecture:")
        print(agent)

        o, info = env.reset()
        # Pure inference: no gradients are needed, so avoid building autograd graphs.
        with torch.no_grad():
            while True:
                obs = _preprocess_obs(o, obs_keys, cnn_keys, mlp_keys, fabric.device)
                actions = agent.get_actions(obs, greedy=True)
                # Collapse each returned action tensor to an index and join them.
                actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)
                o, _, terminated, truncated, info = env.step(
                    actions.cpu().numpy().reshape(env.action_space.shape)
                )
                if terminated or truncated:
                    # The stop flag is read from the info returned by reset().
                    o, info = env.reset()
                    if info["env_done"] or test is True:
                        break
    finally:
        # Close the environment even when something above raised.
        env.close()

    # Return success
    return 0
if __name__ == "__main__":
    # Command-line entry point: gather the config path, checkpoint path and test switch.
    cli = argparse.ArgumentParser()
    cli.add_argument("--cfg_path", type=str, required=True, help="Configuration file")
    cli.add_argument("--checkpoint_path", type=str, default="model", help="Model checkpoint")
    cli.add_argument("--test", action="store_true", help="Test mode")
    args = cli.parse_args()
    # Echo the parsed arguments before starting the agent.
    print(args)
    main(cfg_path=args.cfg_path, checkpoint_path=args.checkpoint_path, test=args.test)