DQN playing SpaceInvadersNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/0511de345b17175b7cf1ea706c3e05981f11761c
938b20d
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os

# Per the PyTorch MPS docs, "1" makes ops unsupported on the MPS backend fall
# back to CPU. It is set before the rl_algo_impls imports below, presumably
# because they import torch, which reads this at initialization — keep this
# assignment above those imports. NOTE(review): confirm the import chain.
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"

from multiprocessing import Pool

from rl_algo_impls.runner.running_utils import base_parser
from rl_algo_impls.runner.train import TrainArgs
from rl_algo_impls.runner.train import train as runner_train
def train() -> None:
    """Parse command-line arguments and run the requested training job(s).

    On top of the options from ``base_parser()``, this adds WandB settings,
    ``--pool-size`` and ``--virtual-display``. The remaining arguments are
    expanded into a list of ``TrainArgs`` via ``TrainArgs.expand_from_dict``.
    A single job runs in the current process. Multiple jobs run in a
    ``multiprocessing.Pool`` of at most ``--pool-size`` workers, with a fresh
    worker process per job.

    If ``--virtual-display`` is given, a headless display is started before
    training and stopped afterwards, even if training raises.
    """
    parser = base_parser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls",
        help="WandB project name to upload training data to. If none, won't upload.",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument(
        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
    )
    parser.add_argument(
        "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # Local-debugging convenience: uncomment to hard-code defaults, e.g.
    # parser.set_defaults(
    #     algo=["ppo"],
    #     env=["CartPole-v1"],
    #     seed=[10],
    #     pool_size=3,
    # )
    args = parser.parse_args()
    print(args)

    # Neither virtual_display nor pool_size is a TrainArg, so both must be
    # removed from args before it is expanded into TrainArgs.
    use_virtual_display = args.virtual_display
    delattr(args, "virtual_display")
    requested_pool_size = args.pool_size
    delattr(args, "pool_size")

    virtual_display = None
    if use_virtual_display:
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()

    try:
        train_args = TrainArgs.expand_from_dict(vars(args))
        if len(train_args) == 1:
            runner_train(train_args[0])
        else:
            # Size the pool by the number of expanded jobs rather than by the
            # number of seeds alone: expand_from_dict can presumably yield
            # several jobs per seed (e.g. multiple algos/envs), and those
            # should still run in parallel. More workers than jobs would only
            # sit idle. Pool rejects sizes < 1, so clamp to at least 1.
            pool_size = max(1, min(requested_pool_size, len(train_args)))
            # Force a new process for each job to get around wandb not allowing
            # more than one wandb.tensorboard.patch call per process.
            with Pool(pool_size, maxtasksperchild=1) as p:
                p.map(runner_train, train_args)
    finally:
        # Don't leave the headless display running after training ends or fails.
        if virtual_display is not None:
            virtual_display.stop()
# Entry-point guard. With the "spawn"/"forkserver" multiprocessing start
# methods, each Pool worker re-imports this module; the guard keeps those
# imports from recursively starting training again.
if __name__ == "__main__":
    train()