|
import gym |
|
import stable_baselines3 |
|
from stable_baselines3 import A2C |
|
import matplotlib.pyplot as plt |
|
|
|
|
|
# Create the classic CartPole balancing task; the episode ends when the
# pole falls over or the cart leaves the track, with +1 reward per step.
env = gym.make("CartPole-v1")




# A2C (advantage actor-critic) with a multilayer-perceptron policy.
# verbose=1 prints training progress/statistics to stdout.
model = A2C("MlpPolicy", env, verbose=1)

model.learn(total_timesteps=10000)




# Persist the trained weights (written as "a2c_cartpole_v1.zip") so the
# agent can be reloaded later without retraining.
model.save("a2c_cartpole_v1")
|
|
|
# Evaluate the trained policy: step the environment for a fixed budget of
# 2000 steps, recording the total reward of every completed episode.
rewards_by_episodes = []  # total reward collected in each finished episode
cum_reward = 0            # reward accumulated in the episode in progress




obs = env.reset()

for _ in range(2000):

    action, _states = model.predict(obs)

    obs, rewards, dones, info = env.step(action)

    env.render()

    # Accumulate the step reward directly instead of testing it against a
    # magic value (CartPole happens to always yield +1 per step, but this
    # form is correct for any reward scheme).
    cum_reward += rewards

    if dones:

        rewards_by_episodes.append(cum_reward)

        # Reassign obs here: the original discarded env.reset()'s return
        # value, so the policy kept predicting from the stale terminal
        # observation of the finished episode.
        obs = env.reset()

        cum_reward = 0



env.close()
|
|
|
|
|
# X axis: one index per recorded episode.
x = list(range(len(rewards_by_episodes)))




plt.figure("Figure 1")

plt.xlabel("Episodes")

plt.ylabel("Reward")

plt.title("Rewards by episodes for SB3-A2C algorithm")

plt.plot(x, rewards_by_episodes)

# Save only after the figure is fully composed: the original called
# savefig before plt.title, so the saved PNG was missing its title.
plt.savefig("a2c_cartpole_v1.png")

plt.show()
|
|