"""Train an A2C agent on CartPole-v1 with Stable-Baselines3, then evaluate it
and plot the per-episode reward curve.

Side effects: saves the trained model to ``a2c_cartpole_v1.zip`` and the
reward plot to ``a2c_cartpole_v1.png``.
"""

import gym
import matplotlib.pyplot as plt
from stable_baselines3 import A2C

# Initialize the environment
env = gym.make("CartPole-v1")

# Define the A2C model and train it
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

# Save the trained model
model.save("a2c_cartpole_v1")

rewards_by_episodes = []
cum_reward = 0.0

# Test the trained model and record the cumulative reward of each episode.
# NOTE: this uses the classic gym API (reset() -> obs, step() -> 4-tuple),
# matching the SB3 version this script was written against.
obs = env.reset()
for _ in range(2000):
    action, _states = model.predict(obs)
    obs, reward, done, info = env.step(action)
    env.render()
    # Accumulate the reward directly instead of comparing floats with ==.
    cum_reward += reward
    if done:
        rewards_by_episodes.append(cum_reward)
        # BUGFIX: the reset observation must be assigned back to `obs`;
        # otherwise predict() keeps seeing the stale terminal observation.
        obs = env.reset()
        cum_reward = 0.0
env.close()

# x axis for the plot: one point per completed episode
x = list(range(len(rewards_by_episodes)))

# Plot rewards by episodes
plt.figure("Figure 1")
plt.xlabel("Episodes")
plt.ylabel("Reward")
plt.plot(x, rewards_by_episodes)
plt.title("Rewards by episodes for SB3-A2C algorithm")
plt.savefig("a2c_cartpole_v1.png")
plt.show()