colasa commited on
Commit
121f68f
1 Parent(s): 6068323

Upload 3 files

Browse files
a2c_cartpole_v1.png ADDED
a2c_cartpole_v1.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2dee6997223d845912eed3add6a2874bb689705c87fe9819b0954bcb4791a405
3
+ size 93582
a2c_sb3_cartpole.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gym
2
+ import stable_baselines3
3
+ from stable_baselines3 import A2C
4
+ import matplotlib.pyplot as plt
5
+
6
+ # Initialize the environment
7
+ env = gym.make("CartPole-v1")
8
+
9
+ # Define the A2C model and learn
10
+ model = A2C("MlpPolicy", env, verbose=1)
11
+ model.learn(total_timesteps=10000)
12
+
13
+ # Save the model
14
+ model.save("a2c_cartpole_v1")
15
+
16
+ rewards_by_episodes = []
17
+ cum_reward = 0
18
+
19
+ # Test the trained model and output the reward
20
+ obs = env.reset()
21
+ for i in range(2000):
22
+ action, _states = model.predict(obs)
23
+ obs, rewards, dones, info = env.step(action)
24
+ env.render()
25
+ if rewards == 1.0:
26
+ cum_reward += 1
27
+ if dones:
28
+ rewards_by_episodes.append(cum_reward)
29
+ env.reset()
30
+ cum_reward = 0
31
+
32
+ env.close()
33
+
34
+ # define x axis for plot :
35
+ x = list(range(len(rewards_by_episodes)))
36
+
37
+ # Plot rewards by episodes
38
+ plt.figure("Figure 1")
39
+ plt.xlabel("Episodes")
40
+ plt.ylabel("Reward")
41
+ plt.plot(x, rewards_by_episodes)
42
+ plt.savefig("a2c_cartpole_v1.png")
43
+ plt.title("Rewards by episodes for SB3-A2C algorithm")
44
+ plt.show()