# File size: 1,768 bytes (revision b638440)
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from rl_algo_impls.runner.running_utils import base_parser
from rl_algo_impls.runner.selfplay_evaluate import SelfplayEvalArgs, selfplay_evaluate
def selfplay_enjoy() -> None:
parser = base_parser(multiple=False)
parser.add_argument(
"--wandb-run-paths",
type=str,
nargs="*",
help="WandB run paths to load players from. Must be 0 or 2",
)
parser.add_argument(
"--model-file-paths",
type=str,
help="File paths to load players from. Must be 0 or 2",
)
parser.add_argument("--render", action="store_true")
parser.add_argument("--n-envs", default=1, type=int)
parser.add_argument("--n-episodes", default=1, type=int)
parser.add_argument("--deterministic-eval", default=None, type=bool)
parser.add_argument(
"--no-print-returns", action="store_true", help="Limit printing"
)
parser.add_argument(
"--video-path", type=str, help="Path to save video of all plays"
)
# parser.set_defaults(
# algo=["ppo"],
# env=["Microrts-selfplay-unet-decay"],
# n_episodes=10,
# model_file_paths=[
# "downloaded_models/ppo-Microrts-selfplay-unet-decay-S3-best",
# "downloaded_models/ppo-Microrts-selfplay-unet-decay-S2-best",
# ],
# video_path="/Users/sgoodfriend/Desktop/decay3-vs-decay2",
# )
args = parser.parse_args()
args.algo = args.algo[0]
args.env = args.env[0]
args.seed = args.seed[0]
args = SelfplayEvalArgs(**vars(args))
selfplay_evaluate(args, os.getcwd())
if __name__ == "__main__":
    # Launch the self-play CLI only when run as a script, not when imported.
    selfplay_enjoy()
|