# docs and experiment results can be found at https://docs.cleanrl.dev/rl-algorithms/td3/#td3_continuous_action_jaxpy
import argparse
import os
import random
import time

import flax
import flax.linen as nn
import gymnasium as gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.training.train_state import TrainState
from stable_baselines3.common.buffers import ReplayBuffer
from torch.utils.tensorboard import SummaryWriter

try:
    from distutils.util import strtobool
except ImportError:  # `distutils` was removed from the standard library in Python 3.12

    def strtobool(val):
        """Convert a truth-value string to 1 or 0, mirroring `distutils.util.strtobool`.

        Accepts (case-insensitively) y/yes/t/true/on/1 as true and
        n/no/f/false/off/0 as false.

        Raises:
            ValueError: if `val` is none of the recognised spellings.
        """
        val = val.lower()
        if val in ("y", "yes", "t", "true", "on", "1"):
            return 1
        if val in ("n", "no", "f", "false", "off", "0"):
            return 0
        raise ValueError("invalid truth value %r" % (val,))


def parse_args():
    """Parse the command-line hyperparameters for this TD3 run.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    # fmt: off
    parser = argparse.ArgumentParser()
    # `splitext` removes only the ".py" suffix; `str.rstrip(".py")` would strip any
    # trailing '.', 'p' or 'y' characters from the file stem as well.
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    # Algorithm specific arguments
    parser.add_argument("--env-id", type=str, default="HalfCheetah-v4",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=1000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=3e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--buffer-size", type=int, default=int(1e6),
        help="the replay memory buffer size")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--tau", type=float, default=0.005,
        help="target smoothing coefficient (default: 0.005)")
    parser.add_argument("--policy-noise", type=float, default=0.2,
        help="the scale of policy noise")
    parser.add_argument("--batch-size", type=int, default=256,
        help="the batch size of sample from the replay memory")
    parser.add_argument("--exploration-noise", type=float, default=0.1,
        help="the scale of exploration noise")
    # argparse does not run `type` on non-string defaults, so the default must already be an int.
    parser.add_argument("--learning-starts", type=int, default=int(25e3),
        help="timestep to start learning")
    parser.add_argument("--policy-frequency", type=int, default=2,
        help="the frequency of training policy (delayed)")
    parser.add_argument("--noise-clip", type=float, default=0.5,
        help="noise clip parameter of the Target Policy Smoothing Regularization")
    args = parser.parse_args()
    # fmt: on
    return args


def make_env(env_id, seed, idx, capture_video, run_name):
    """Return a zero-argument thunk that builds one environment.

    The env with ``idx == 0`` is wrapped in ``RecordVideo`` (writing to
    ``videos/{run_name}``) when ``capture_video`` is set. Every env is wrapped in
    ``RecordEpisodeStatistics`` and has its action space seeded with ``seed``.
    """

    def thunk():
        if capture_video and idx == 0:
            env = gym.make(env_id, render_mode="rgb_array")
            env = gym.wrappers.RecordVideo(env, f"videos/{run_name}")
        else:
            env = gym.make(env_id)
        env = gym.wrappers.RecordEpisodeStatistics(env)
        env.action_space.seed(seed)
        return env

    return thunk


# ALGO LOGIC: initialize agent here:
class QNetwork(nn.Module):
    """Critic: an MLP mapping a concatenated (observation, action) pair to one Q-value."""

    @nn.compact
    def __call__(self, x: jnp.ndarray, a: jnp.ndarray):
        x = jnp.concatenate([x, a], -1)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(1)(x)
        return x


class Actor(nn.Module):
    """Deterministic policy: an MLP whose tanh output is rescaled to the action bounds."""

    action_dim: int
    # (high - low) / 2 and (high + low) / 2 of the action space; map tanh's [-1, 1] onto [low, high].
    action_scale: jnp.ndarray
    action_bias: jnp.ndarray

    @nn.compact
    def __call__(self, x):
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(256)(x)
        x = nn.relu(x)
        x = nn.Dense(self.action_dim)(x)
        x = nn.tanh(x)
        x = x * self.action_scale + self.action_bias
        return x


class TrainState(TrainState):
    """Flax ``TrainState`` extended with a slowly-updated copy of the parameters."""

    # Polyak-averaged target-network parameters.
    target_params: flax.core.FrozenDict


if __name__ == "__main__":
    import stable_baselines3 as sb3

    # NOTE(review): this is a plain string comparison of version numbers; it is adequate for
    # distinguishing the 1.x and 2.x series but would misorder e.g. "10.0".
    if sb3.__version__ < "2.0":
        raise ValueError(
            """Ongoing migration: run the following command to install the new dependencies:

poetry run pip install "stable_baselines3==2.0.0a1"
"""
        )
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{int(time.time())}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    # TRY NOT TO MODIFY: seeding
    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, actor_key, qf1_key, qf2_key = jax.random.split(key, 4)

    # env setup
    envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
    assert isinstance(envs.single_action_space, gym.spaces.Box), "only continuous action space is supported"

    max_action = float(envs.single_action_space.high[0])
    # Force float32 observation storage in the replay buffer.
    envs.single_observation_space.dtype = np.float32
    rb = ReplayBuffer(
        args.buffer_size,
        envs.single_observation_space,
        envs.single_action_space,
        device="cpu",
        handle_timeout_termination=False,
    )

    # TRY NOT TO MODIFY: start the game
    obs, _ = envs.reset(seed=args.seed)
    actor = Actor(
        action_dim=np.prod(envs.single_action_space.shape),
        action_scale=jnp.array((envs.action_space.high - envs.action_space.low) / 2.0),
        action_bias=jnp.array((envs.action_space.high + envs.action_space.low) / 2.0),
    )
    # Initialising params and target_params with the same key makes them start identical.
    actor_state = TrainState.create(
        apply_fn=actor.apply,
        params=actor.init(actor_key, obs),
        target_params=actor.init(actor_key, obs),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf = QNetwork()
    qf1_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf1_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf1_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    qf2_state = TrainState.create(
        apply_fn=qf.apply,
        params=qf.init(qf2_key, obs, envs.action_space.sample()),
        target_params=qf.init(qf2_key, obs, envs.action_space.sample()),
        tx=optax.adam(learning_rate=args.learning_rate),
    )
    actor.apply = jax.jit(actor.apply)
    qf.apply = jax.jit(qf.apply)

    @jax.jit
    def update_critic(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
        actions: np.ndarray,
        next_observations: np.ndarray,
        rewards: np.ndarray,
        terminations: np.ndarray,
        key: jnp.ndarray,
    ):
        """Take one optimizer step on both critics towards the clipped double-Q target.

        The target action is the target actor's output plus Gaussian noise. The noise is
        clipped to [-noise_clip, noise_clip], scaled by ``action_scale``, and the resulting
        action is clipped to the action bounds. The bootstrap value is the minimum of the
        two target critics, zeroed where ``terminations`` is 1.

        Returns:
            ``(qf1_state, qf2_state)``, ``(qf1_loss, qf2_loss)``, the mean predicted Q-values
            of each critic on the batch, and the advanced PRNG key.
        """
        # TODO Maybe pre-generate a lot of random keys
        # also check https://jax.readthedocs.io/en/latest/jax.random.html
        key, noise_key = jax.random.split(key, 2)
        clipped_noise = (
            jnp.clip(
                (jax.random.normal(noise_key, actions.shape) * args.policy_noise),
                -args.noise_clip,
                args.noise_clip,
            )
            * actor.action_scale
        )
        next_state_actions = jnp.clip(
            actor.apply(actor_state.target_params, next_observations) + clipped_noise,
            envs.single_action_space.low,
            envs.single_action_space.high,
        )
        qf1_next_target = qf.apply(qf1_state.target_params, next_observations, next_state_actions).reshape(-1)
        qf2_next_target = qf.apply(qf2_state.target_params, next_observations, next_state_actions).reshape(-1)
        min_qf_next_target = jnp.minimum(qf1_next_target, qf2_next_target)
        next_q_value = (rewards + (1 - terminations) * args.gamma * (min_qf_next_target)).reshape(-1)

        def mse_loss(params):
            qf_a_values = qf.apply(params, observations, actions).squeeze()
            return ((qf_a_values - next_q_value) ** 2).mean(), qf_a_values.mean()

        (qf1_loss_value, qf1_a_values), grads1 = jax.value_and_grad(mse_loss, has_aux=True)(qf1_state.params)
        (qf2_loss_value, qf2_a_values), grads2 = jax.value_and_grad(mse_loss, has_aux=True)(qf2_state.params)
        qf1_state = qf1_state.apply_gradients(grads=grads1)
        qf2_state = qf2_state.apply_gradients(grads=grads2)

        return (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key

    @jax.jit
    def update_actor(
        actor_state: TrainState,
        qf1_state: TrainState,
        qf2_state: TrainState,
        observations: np.ndarray,
    ):
        """Take one policy step maximising Q1, then Polyak-update all three target networks.

        The actor loss is ``-mean(Q1(s, actor(s)))``. After the gradient step, each
        ``target_params`` is moved towards ``params`` with step size ``args.tau`` via
        ``optax.incremental_update``.

        Returns:
            The updated actor state, ``(qf1_state, qf2_state)``, and the actor loss.
        """

        def actor_loss(params):
            return -qf.apply(qf1_state.params, observations, actor.apply(params, observations)).mean()

        actor_loss_value, grads = jax.value_and_grad(actor_loss)(actor_state.params)
        actor_state = actor_state.apply_gradients(grads=grads)
        actor_state = actor_state.replace(
            target_params=optax.incremental_update(actor_state.params, actor_state.target_params, args.tau)
        )

        qf1_state = qf1_state.replace(
            target_params=optax.incremental_update(qf1_state.params, qf1_state.target_params, args.tau)
        )
        qf2_state = qf2_state.replace(
            target_params=optax.incremental_update(qf2_state.params, qf2_state.target_params, args.tau)
        )
        return actor_state, (qf1_state, qf2_state), actor_loss_value

    start_time = time.time()
    # Stays None until the first delayed policy update, so logging cannot hit an unbound name
    # when `policy_frequency` is large relative to the logging interval.
    actor_loss_value = None
    for global_step in range(args.total_timesteps):
        # ALGO LOGIC: put action logic here
        if global_step < args.learning_starts:
            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
        else:
            actions = actor.apply(actor_state.params, obs)
            # Gaussian exploration noise (std = max_action * exploration_noise), then clip to the bounds.
            actions = np.array(
                [
                    (
                        jax.device_get(actions)[0]
                        + np.random.normal(0, max_action * args.exploration_noise, size=envs.single_action_space.shape)
                    ).clip(envs.single_action_space.low, envs.single_action_space.high)
                ]
            )

        # TRY NOT TO MODIFY: execute the game and log data.
        next_obs, rewards, terminations, truncations, infos = envs.step(actions)

        # TRY NOT TO MODIFY: record rewards for plotting purposes
        if "final_info" in infos:
            for info in infos["final_info"]:
                # Sub-envs that did not finish this step have no final info entry.
                if info is None or "episode" not in info:
                    continue
                print(f"global_step={global_step}, episodic_return={info['episode']['r']}")
                writer.add_scalar("charts/episodic_return", info["episode"]["r"], global_step)
                writer.add_scalar("charts/episodic_length", info["episode"]["l"], global_step)
                break

        # TRY NOT TO MODIFY: save data to replay buffer; handle `terminal_observation`
        # On truncation the vector env has already reset, so store the true final observation.
        real_next_obs = next_obs.copy()
        for idx, trunc in enumerate(truncations):
            if trunc:
                real_next_obs[idx] = infos["final_observation"][idx]
        rb.add(obs, real_next_obs, actions, rewards, terminations, infos)

        # TRY NOT TO MODIFY: CRUCIAL step easy to overlook
        obs = next_obs

        # ALGO LOGIC: training.
        if global_step > args.learning_starts:
            data = rb.sample(args.batch_size)

            (qf1_state, qf2_state), (qf1_loss_value, qf2_loss_value), (qf1_a_values, qf2_a_values), key = update_critic(
                actor_state,
                qf1_state,
                qf2_state,
                data.observations.numpy(),
                data.actions.numpy(),
                data.next_observations.numpy(),
                data.rewards.flatten().numpy(),
                data.dones.flatten().numpy(),
                key,
            )

            # Delayed policy (and target-network) updates.
            if global_step % args.policy_frequency == 0:
                actor_state, (qf1_state, qf2_state), actor_loss_value = update_actor(
                    actor_state,
                    qf1_state,
                    qf2_state,
                    data.observations.numpy(),
                )

            if global_step % 100 == 0:
                writer.add_scalar("losses/qf1_loss", qf1_loss_value.item(), global_step)
                writer.add_scalar("losses/qf2_loss", qf2_loss_value.item(), global_step)
                writer.add_scalar("losses/qf1_values", qf1_a_values.item(), global_step)
                writer.add_scalar("losses/qf2_values", qf2_a_values.item(), global_step)
                if actor_loss_value is not None:
                    writer.add_scalar("losses/actor_loss", actor_loss_value.item(), global_step)
                print("SPS:", int(global_step / (time.time() - start_time)))
                writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

    if args.save_model:
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(
                flax.serialization.to_bytes(
                    [
                        actor_state.params,
                        qf1_state.params,
                        qf2_state.params,
                    ]
                )
            )
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.td3_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=(Actor, QNetwork),
            exploration_noise=args.exploration_noise,
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(args, episodic_returns, repo_id, "TD3", f"runs/{run_name}", f"videos/{run_name}-eval")

    envs.close()
    writer.close()