# a2c_sb3_cartpole.py — train and evaluate a Stable-Baselines3 A2C agent
# on the Gym CartPole-v1 environment, then plot per-episode rewards.
# (Original file hosted under cartpole-v1/, uploaded by colasa.)
import gym
import stable_baselines3
from stable_baselines3 import A2C
import matplotlib.pyplot as plt
# Initialize the environment
env = gym.make("CartPole-v1")

# Define the A2C model and learn
model = A2C("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10000)

# Save the model
model.save("a2c_cartpole_v1")

rewards_by_episodes = []
cum_reward = 0

# Test the trained model and output the reward.
# NOTE(review): this uses the classic gym API where reset() returns obs and
# step() returns (obs, reward, done, info) — confirm against the installed
# gym version (gymnasium / gym>=0.26 return extra values).
obs = env.reset()
for i in range(2000):
    action, _states = model.predict(obs)
    obs, rewards, dones, info = env.step(action)
    env.render()
    # CartPole-v1 yields a reward of +1 per surviving step; accumulate it
    # directly instead of the fragile float-equality check `rewards == 1.0`.
    cum_reward += rewards
    if dones:
        rewards_by_episodes.append(cum_reward)
        # Bug fix: capture the fresh observation — the original discarded
        # env.reset()'s return value, leaving `obs` stuck at the terminal
        # state for all remaining iterations.
        obs = env.reset()
        cum_reward = 0
# NOTE: a final, unfinished episode's partial reward is intentionally dropped
# (only completed episodes are plotted).
env.close()

# define x axis for plot :
x = list(range(len(rewards_by_episodes)))

# Plot rewards by episodes
plt.figure("Figure 1")
plt.xlabel("Episodes")
plt.ylabel("Reward")
plt.plot(x, rewards_by_episodes)
# Bug fix: set the title BEFORE savefig, otherwise the saved PNG has no title.
plt.title("Rewards by episodes for SB3-A2C algorithm")
plt.savefig("a2c_cartpole_v1.png")
plt.show()