|
import argparse |
|
import json |
|
|
|
import gymnasium as gym |
|
import torch |
|
from lightning import Fabric |
|
from omegaconf import OmegaConf |
|
from sheeprl.algos.ppo.agent import build_agent |
|
from sheeprl.utils.env import make_env |
|
from sheeprl.utils.utils import dotdict |
|
|
|
"""This is an example agent based on SheepRL. |
|
|
|
Usage: |
|
diambra run python sheeprl/agent.py --cfg_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/config.yaml" --checkpoint_path "./fake-logs/runs/ppo/doapp/fake-experiment/version_0/checkpoint/ckpt_1024_0.ckpt" |
|
""" |
|
|
|
|
|
def main(cfg_path: str, checkpoint_path: str, test=False):
    """Run a trained SheepRL PPO agent in a single environment.

    The configuration is loaded from ``cfg_path`` and printed, a Fabric
    instance and one environment are created, the agent is rebuilt from the
    ``"agent"`` entry of the checkpoint, and the agent is then stepped
    greedily until the stop condition below is met.

    Args:
        cfg_path: Path to the configuration file read with ``OmegaConf.load``.
        checkpoint_path: Path to the checkpoint read with ``fabric.load``.
        test: When falsy, ``cfg.env.capture_video`` is forced to ``False`` and
            ``cfg.env.num_envs`` to ``1`` before the environment is built.
            When ``True``, the loop stops after the first finished episode.

    Returns:
        ``0`` after the loop has ended and the environment has been closed.
    """
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Outside test mode, override the loaded settings: no video capture and a
    # single environment.
    if not test:
        cfg.env.capture_video = False
        cfg.env.num_envs = 1

    # Precision and plugins are taken from the loaded configuration when
    # present; everything else is fixed to a single device on a single node.
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    # make_env returns a factory; calling it yields the environment instance.
    env = make_env(cfg, 0, 0)()
    try:
        observation_space = env.observation_space
        is_multidiscrete = isinstance(env.action_space, gym.spaces.MultiDiscrete)
        actions_dim = tuple(
            env.action_space.nvec.tolist() if is_multidiscrete else [env.action_space.n]
        )
        cnn_keys = cfg.algo.cnn_keys.encoder
        mlp_keys = cfg.algo.mlp_keys.encoder
        obs_keys = mlp_keys + cnn_keys

        state = fabric.load(checkpoint_path)

        # Only the last object returned by build_agent is used here.
        # NOTE(review): presumably that is the inference-time player; confirm
        # against sheeprl.algos.ppo.agent.build_agent.
        agent = build_agent(
            fabric=fabric,
            actions_dim=actions_dim,
            is_continuous=False,
            cfg=cfg,
            obs_space=observation_space,
            agent_state=state["agent"],
        )[-1]
        agent.eval()

        print("Policy architecture:")
        print(agent)

        o, info = env.reset()

        while True:
            # Convert the selected numpy observations to batched torch tensors.
            obs = {}
            for k, value in o.items():
                if k not in obs_keys:
                    continue
                torch_obs = torch.from_numpy(value.copy()).to(fabric.device).unsqueeze(0)
                if k in cnn_keys:
                    # Keep the last two axes, fold everything else into the
                    # second axis, then rescale (0..255 -> -0.5..0.5, assuming
                    # 8-bit pixel values).
                    torch_obs = (
                        torch_obs.reshape(1, -1, *torch_obs.shape[-2:]) / 255 - 0.5
                    )
                if k in mlp_keys:
                    torch_obs = torch_obs.float()
                obs[k] = torch_obs

            # Inference only: do not record an autograd graph.
            with torch.no_grad():
                actions = agent.get_actions(obs, greedy=True)
                # One tensor per action component; argmax over the last axis
                # turns each into an index before they are concatenated.
                actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

            o, _, terminated, truncated, info = env.step(
                actions.cpu().numpy().reshape(env.action_space.shape)
            )

            if terminated or truncated:
                o, info = env.reset()
                # The info dict checked here is the one returned by reset().
                if info["env_done"] or test is True:
                    break
    finally:
        # Release the environment even when loading or stepping raises.
        env.close()

    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Script entry point: declare the command-line options, echo the parsed
    # values, then hand them to main().
    cli = argparse.ArgumentParser()
    option_specs = (
        ("--cfg_path", {"type": str, "required": True, "help": "Configuration file"}),
        ("--checkpoint_path", {"type": str, "default": "model", "help": "Model checkpoint"}),
        ("--test", {"action": "store_true", "help": "Test mode"}),
    )
    for flag, settings in option_specs:
        cli.add_argument(flag, **settings)

    args = cli.parse_args()
    print(args)

    main(
        cfg_path=args.cfg_path,
        checkpoint_path=args.checkpoint_path,
        test=args.test,
    )
|
|