|
import argparse |
|
import json |
|
|
|
import gymnasium as gym |
|
import torch |
|
from lightning import Fabric |
|
from omegaconf import OmegaConf |
|
from sheeprl.algos.dreamer_v3.agent import build_agent |
|
from sheeprl.algos.dreamer_v3.utils import prepare_obs |
|
from sheeprl.utils.env import make_env |
|
from sheeprl.utils.utils import dotdict |
|
|
|
"""This is an example agent based on SheepRL. |
|
|
|
Usage: |
|
cd sheeprl |
|
diambra run python agent-dreamer_v3.py --cfg_path "./results/dreamer_v3/config.yaml" --checkpoint_path "./results/dreamer_v3/ckpt_1024_0.ckpt" |
|
""" |
|
|
|
|
|
def main(cfg_path: str, checkpoint_path: str, test: bool = False) -> int:
    """Run a trained DreamerV3 agent in a single environment.

    The configuration is loaded from ``cfg_path``, the model weights from
    ``checkpoint_path``, and the agent then acts in the environment until the
    stop condition is met.

    Args:
        cfg_path: Path to the configuration file, loaded with OmegaConf.
        checkpoint_path: Path to the checkpoint; it must provide the
            ``"world_model"``, ``"actor"``, ``"critic"`` and
            ``"target_critic"`` entries.
        test: When ``True``, stop as soon as the first episode terminates or
            is truncated instead of waiting for ``info["env_done"]``.

    Returns:
        ``0`` once the interaction loop has finished.
    """
    cfg = dotdict(OmegaConf.to_container(OmegaConf.load(cfg_path), resolve=True))
    print("Config parameters = ", json.dumps(cfg, sort_keys=True, indent=4))

    # Evaluation overrides: no video capture and a single environment.
    cfg.env.capture_video = False
    cfg.env.num_envs = 1

    # Reuse the precision/plugins stored in the config (None when absent).
    precision = getattr(cfg.fabric, "precision", None)
    plugins = getattr(cfg.fabric, "plugins", None)
    fabric = Fabric(
        accelerator="auto",
        devices=1,
        num_nodes=1,
        precision=precision,
        plugins=plugins,
        strategy="auto",
    )

    env = make_env(cfg, 0, 0)()
    # Everything after the environment is created runs under try/finally so
    # the environment is closed even when loading or stepping raises.
    try:
        observation_space = env.observation_space
        action_space = env.action_space
        is_multidiscrete = isinstance(action_space, gym.spaces.MultiDiscrete)
        actions_dim = tuple(
            action_space.nvec.tolist() if is_multidiscrete else [action_space.n]
        )
        cnn_keys = cfg.algo.cnn_keys.encoder

        state = fabric.load(checkpoint_path)

        # Only the last element returned by build_agent is used to act.
        agent = build_agent(
            fabric=fabric,
            actions_dim=actions_dim,
            is_continuous=False,
            cfg=cfg,
            obs_space=observation_space,
            world_model_state=state["world_model"],
            actor_state=state["actor"],
            critic_state=state["critic"],
            target_critic_state=state["target_critic"],
        )[-1]
        agent.eval()

        print("Policy architecture:")
        print(agent)

        obs, info = env.reset()
        # The agent's states are re-initialised after every environment reset.
        agent.init_states()

        while True:
            # Pure inference: disable autograd so no graph is built per step.
            with torch.no_grad():
                torch_obs = prepare_obs(fabric, obs, cnn_keys=cnn_keys)
                actions = agent.get_actions(torch_obs, greedy=False)
                # Reduce each returned action tensor to its argmax index and
                # join the indices into a single tensor.
                actions = torch.cat([act.argmax(dim=-1) for act in actions], dim=-1)

            obs, _, terminated, truncated, info = env.step(
                actions.cpu().numpy().reshape(action_space.shape)
            )

            if terminated or truncated:
                obs, info = env.reset()
                agent.init_states()
                # The stop flag is read from the info returned by reset().
                if info["env_done"] or test is True:
                    break
    finally:
        env.close()

    return 0
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the options, echo them, run the agent.
    arg_parser = argparse.ArgumentParser()

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        ("--cfg_path", {"type": str, "required": True, "help": "Configuration file"}),
        ("--checkpoint_path", {"type": str, "default": "model", "help": "Model checkpoint"}),
        ("--test", {"action": "store_true", "help": "Test mode"}),
    )
    for flag, spec in option_specs:
        arg_parser.add_argument(flag, **spec)

    cli_args = arg_parser.parse_args()
    print(cli_args)

    main(
        cfg_path=cli_args.cfg_path,
        checkpoint_path=cli_args.checkpoint_path,
        test=cli_args.test,
    )
|
|