# sgoodfriend's picture
# A2C playing BreakoutNoFrameskip-v4 from https://github.com/sgoodfriend/rl-algo-impls/tree/2067e21d62fff5db60168687e7d9e89019a8bfc0
# 32c4441
# raw
# history blame
# 2.08 kB
# Support for PyTorch mps mode (https://pytorch.org/docs/stable/notes/mps.html)
import os
os.environ["PYTORCH_ENABLE_MPS_FALLBACK"] = "1"
from multiprocessing import Pool
from rl_algo_impls.runner.running_utils import base_parser
from rl_algo_impls.runner.train import train as runner_train, TrainArgs
def train() -> None:
    """Parse command-line arguments and launch one or more training jobs.

    Extends the parser returned by ``base_parser()`` with WandB, pool-size and
    virtual-display options, expands the parsed arguments into a list of
    ``TrainArgs`` via ``TrainArgs.expand_from_dict`` and calls ``runner_train``
    for each entry. A single job runs in the current process; several jobs are
    distributed over a ``multiprocessing.Pool``.
    """
    parser = base_parser()
    parser.add_argument(
        "--wandb-project-name",
        type=str,
        default="rl-algo-impls",
        help="WandB project name to upload training data to. If none, won't upload.",
    )
    parser.add_argument(
        "--wandb-entity",
        type=str,
        default=None,
        help="WandB team of project. None uses default entity",
    )
    parser.add_argument(
        "--wandb-tags", type=str, nargs="*", help="WandB tags to add to run"
    )
    parser.add_argument(
        "--pool-size", type=int, default=1, help="Simultaneous training jobs to run"
    )
    parser.add_argument(
        "--virtual-display", action="store_true", help="Use headless virtual display"
    )
    # parser.set_defaults(
    #     algo=["ppo"],
    #     env=["CartPole-v1"],
    #     seed=[10],
    #     pool_size=3,
    # )
    args = parser.parse_args()
    print(args)

    virtual_display = None
    if args.virtual_display:
        # Imported lazily so pyvirtualdisplay is only required when requested.
        from pyvirtualdisplay.display import Display

        virtual_display = Display(visible=False, size=(1400, 900))
        virtual_display.start()
    # virtual_display isn't a TrainArg so must be removed
    delattr(args, "virtual_display")

    # Never start more workers than there are seeds.
    # NOTE(review): the job list below may be longer than the seed list;
    # confirm whether len(train_args) would be the better bound.
    pool_size = min(args.pool_size, len(args.seed))
    # pool_size isn't a TrainArg so must be removed from args
    delattr(args, "pool_size")

    try:
        train_args = TrainArgs.expand_from_dict(vars(args))
        if len(train_args) == 1:
            runner_train(train_args[0])
        else:
            # Force a new process for each job to get around wandb not allowing
            # more than one wandb.tensorboard.patch call per process.
            # maxtasksperchild counts *chunks*, and Pool.map batches several
            # items into one chunk by default, so chunksize=1 is required to
            # really get one job per worker process.
            with Pool(pool_size, maxtasksperchild=1) as p:
                p.map(runner_train, train_args, chunksize=1)
    finally:
        # Shut the virtual display down even if training raised.
        if virtual_display is not None:
            virtual_display.stop()
# Script entry point: start training only when this file is executed directly.
# The guard also lets multiprocessing workers import this module without
# kicking off training again.
if __name__ == "__main__":
    train()