# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/dqn/#dqn_atari_jaxpy
import argparse
import os
import random
import time
from distutils.util import strtobool

os.environ[
    "XLA_PYTHON_CLIENT_MEM_FRACTION"
] = "0.7"  # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991

import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.atari_wrappers import (
    ClipRewardEnv,
    EpisodicLifeEnv,
    FireResetEnv,
    MaxAndSkipEnv,
    NoopResetEnv,
)
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    """Parse the command-line arguments for this DQN (JAX) Atari experiment.

    Returns:
        argparse.Namespace holding the experiment settings (seeding, tracking,
        model saving/uploading) and the algorithm hyperparameters.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # NOTE: `str.rstrip(".py")` strips any trailing '.', 'p' or 'y' characters
    # (e.g. "happy.py" -> "ha"); `os.path.splitext` removes only the extension.
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="BreakoutNoFrameskip-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=10000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=1e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=1000000,
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--target-network-frequency", type=int, default=1000,
        help="the timesteps it takes to update the target network")
    parser.add_argument("--batch-size", type=int, default=32,
        help="the batch size of sample from the reply memory")
    parser.add_argument("--start-e", type=float, default=1,
        help="the starting epsilon for exploration")
    parser.add_argument("--end-e", type=float, default=0.01,
        help="the ending epsilon for exploration")
    parser.add_argument("--exploration-fraction", type=float, default=0.10,
        help="the fraction of `total-timesteps` it takes from start-e to go end-e")
    parser.add_argument("--learning-starts", type=int, default=80000,
        help="timestep to start learning")
    parser.add_argument("--train-frequency", type=int, default=4,
        help="the frequency of training")
    args = parser.parse_args()
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a thunk that builds one wrapped Atari environment.

    The thunk creates `env_id`, records episode statistics, optionally records
    video (only for the env with `idx == 0`), applies the SB3 Atari wrappers
    (no-op reset, max-and-skip 4, episodic life, fire reset when FIRE is an
    available action, reward clipping), resizes to 84x84, converts to
    grayscale, stacks 4 frames, and seeds the env and its spaces with `seed`.
    """

    def thunk():
        env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        if capture_video:
            if idx == 0:
                env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        env = NoopResetEnv(env, noop_max=30)
        env = MaxAndSkipEnv(env, skip=4)
        env = EpisodicLifeEnv(env)
        if "FIRE" in env.unwrapped.get_action_meanings():
            env = FireResetEnv(env)
        env = ClipRewardEnv(env)
        env = gym.wrappers.ResizeObservation(env, (84, 84))
        env = gym.wrappers.GrayScaleObservation(env)
        env = gym.wrappers.FrameStack(env, 4)
        env.seed(seed)
        env.action_space.seed(seed)
        env.observation_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """Convolutional Q-network mapping stacked frames to one Q-value per action."""

    action_dim: int

    @nn.compact
    def __call__(self, x):
        # Move axis 1 (the frame-stack axis) last: Flax convolutions expect channels-last input.
        x = jnp.transpose(x, (0, 2, 3, 1))
        # Scale raw pixel values from [0, 255] to [0, 1].
        x = x / (255.0)
        x = nn.Conv(32, kernel_size=(8, 8), strides=(4, 4), padding="VALID")(x)
        x = nn.relu(x)
        x = nn.Conv(64, kernel_size=(4, 4), strides=(2, 2), padding="VALID")(x)
        x = nn.relu(x)
        x = nn.Conv(64, kernel_size=(3, 3), strides=(1, 1), padding="VALID")(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(512)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        return x


class TrainState(TrainState):
    """Flax TrainState extended with the parameters of the DQN target network."""

    target_params: flax.core.FrozenDict


def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
    """Linearly interpolate from `start_e` to `end_e` over `duration` steps.

    Returns `start_e` at `t == 0`, and `end_e` once `t >= duration`. The
    schedule is clamped from below by `end_e`, so it assumes
    `end_e <= start_e`. A non-positive `duration` means the schedule has
    already finished, so `end_e` is returned instead of dividing by zero.
    """
    if duration <= 0:
        return end_e
    slope = (end_e - start_e) / duration
    return max(slope * t + start_e, end_e)


if __name__ == "__main__":
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, q_key = jax.random.split(key, 2)

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    obs = envs.reset()

    q_network = QNetwork(action_dim=envs.single_action_space.n)

    q_state = TrainState.create(
        apply_fn=q_network.apply,
        params=q_network.init(q_key, obs),
        target_params=q_network.init(q_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )

    q_network.apply = jax.jit(q_network.apply)
    # This step is not necessary as init called on same observation and key will always lead to same initializations
    q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        "cpu",
        optimize_memory_usage=True,
        handle_timeout_termination=True,
    )

    @jax.jit
    def update(q_state, observations, actions, next_observations, rewards, dones):
        """Take one gradient step on the one-step TD (MSE) loss.

        The bootstrap target uses the max Q-value of the target network and is
        zeroed for terminal transitions via `(1 - dones)`.

        Returns:
            (loss_value, q_pred, q_state): the scalar loss, the predicted
            Q-values of the taken actions, and the updated train state.
        """
        q_next_target = q_network.apply(q_state.target_params, next_observations)  # (batch_size, num_actions)
        q_next_target = jnp.max(q_next_target, axis=-1)  # (batch_size,)
        next_q_value = rewards + (1 - dones) * args.gamma * q_next_target

        def mse_loss(params):
            q_pred = q_network.apply(params, observations)  # (batch_size, num_actions)
            # Select the Q-value of the action actually taken in each transition.
            q_pred = q_pred[np.arange(q_pred.shape[0]), actions.squeeze()]  # (batch_size,)
            return ((q_pred - next_q_value) ** 2).mean(), q_pred

        (loss_value, q_pred), grads = jax.value_and_grad(mse_loss, has_aux=True)(q_state.params)
        q_state = q_state.apply_gradients(grads=grads)
        return loss_value, q_pred, q_state

    start_time = time.time()

    # TRY NOT TO MODIFY: start the game
    obs = envs.reset()
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
        if random.random() < epsilon:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            q_values = q_network.apply(q_state.params, obs)
            actions = q_values.argmax(axis=-1)
            actions = jax.device_get(actions)

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, dones, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        for info in infos:
            if "episode" in info.keys():
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                writer.add_scalar("charts/epsilon", epsilon, global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        real_next_obs = next_obs.copy()
        for idx, d in enumerate(dones):
            if d:
                # The vector env resets a finished sub-env, so `next_obs` holds the new
                # episode's first observation; store the true final one from `infos`.
                real_next_obs[idx] = infos[idx]["terminal_observation"]
        rb.add(obs, real_next_obs, actions, rewards, dones, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            if global_step % args.train_frequency == 0:
                data = rb.sample(args.batch_size)
                # perform a gradient-descent step
                loss, old_val, q_state = update(
                    q_state,
                    data.observations.numpy(),
                    data.actions.numpy(),
                    data.next_observations.numpy(),
                    data.rewards.flatten().numpy(),
                    data.dones.flatten().numpy(),
                )

                if global_step % 100 == 0:
                    writer.add_scalar("losses/td_loss", jax.device_get(loss), global_step)
                    writer.add_scalar("losses/q_values", jax.device_get(old_val).mean(), global_step)
                    print("SPS:", int(global_step / (time.time() - start_time)))
                    writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

            # update the target network
            if global_step % args.target_network_frequency == 0:
                # A step size of 1 makes `incremental_update` a hard copy of the online params.
                q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(flax.serialization.to_bytes(q_state.params))
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.dqn_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=QNetwork,
            epsilon=0.05,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "DQN", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()