jackvial committed
Commit 0484cb0
0 parent(s)
Files changed (6)
  1. .gitignore +3 -0
  2. .vscode/launch.json +15 -0
  3. agent.py +30 -0
  4. environment.py +179 -0
  5. main.py +157 -0
  6. requirements.txt +6 -0
.gitignore ADDED
@@ -0,0 +1,3 @@
+ .env
+ frozen_lake_env
+ __pycache__
.vscode/launch.json ADDED
@@ -0,0 +1,15 @@
+ {
+     "version": "0.2.0",
+     "configurations": [
+         {
+             "name": "Debug",
+             "type": "python",
+             "request": "launch",
+             "python": "${workspaceFolder}/frozen_lake_env/bin/python3.10",
+             "program": "${workspaceFolder}/main.py",
+             "envFile": "${workspaceFolder}/.env",
+             "console": "integratedTerminal",
+             "justMyCode": false
+         }
+     ]
+ }
agent.py ADDED
@@ -0,0 +1,30 @@
+ import numpy as np
+
+
+ class QLearningAgent:
+     def __init__(self, env) -> None:
+         self.env = env
+         self.q_table = self.build_q_table(env.observation_space.n, env.action_space.n)
+
+     def build_q_table(self, n_states, n_actions):
+         return np.zeros((n_states, n_actions))
+
+     def epsilon_greedy_policy(self, state, epsilon):
+         # With probability epsilon take a random action; otherwise take the
+         # action with the highest Q-value for the current state
+         if np.random.random() < epsilon:
+             return np.random.choice(self.env.action_space.n)
+         return np.argmax(self.q_table[state])
+
+     def greedy_policy(self, state):
+         return np.argmax(self.q_table[state])
+
+     def update_q_table(self, state, action, reward, gamma, learning_rate, new_state):
+         # Q(s,a) := Q(s,a) + lr * [R(s,a) + gamma * max_a' Q(s',a') - Q(s,a)]
+         current_q = self.q_table[state][action]
+         next_max_q = np.max(self.q_table[new_state])
+         self.q_table[state][action] = current_q + learning_rate * (
+             reward + gamma * next_max_q - current_q
+         )
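
The update in update_q_table is the standard tabular Q-learning rule. A minimal sketch of exercising the agent class on its own, assuming gym's built-in FrozenLake-v1 rather than the custom curses environment added below:

import gym
from agent import QLearningAgent

env = gym.make("FrozenLake-v1", is_slippery=True)
agent = QLearningAgent(env)

# One interaction step; epsilon=1.0 means the action is chosen uniformly at random.
state, info = env.reset()
action = agent.epsilon_greedy_policy(state, epsilon=1.0)
new_state, reward, done, truncated, info = env.step(action)
agent.update_q_table(state, action, reward, gamma=0.99, learning_rate=0.1, new_state=new_state)
print(agent.q_table.shape)  # (16, 4) for the 4x4 map

At most the visited (state, action) entry changes per update; every other cell of the table stays at zero.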
environment.py ADDED
@@ -0,0 +1,179 @@
+ import re
+ import curses
+ import numpy as np
+ import collections
+ import warnings
+ from typing import Optional
+ from gym.envs.toy_text.frozen_lake import FrozenLakeEnv
+
+ warnings.filterwarnings("ignore")
+
+
+ class FrozenLakeEnvCustom(FrozenLakeEnv):
+     def __init__(
+         self,
+         render_mode: Optional[str] = None,
+         desc=None,
+         map_name="4x4",
+         is_slippery=True,
+     ):
+         self.curses_screen = curses.initscr()
+         curses.start_color()
+         curses.curs_set(0)
+         self.curses_color_pairs = self.build_ncurses_color_pairs()
+
+         # Blocking reads
+         self.curses_screen.timeout(-1)
+
+         super().__init__(
+             render_mode=render_mode,
+             desc=desc,
+             map_name=map_name,
+             is_slippery=is_slippery,
+         )
+
+     def build_ncurses_color_pairs(self):
+         """
+         Based on DeepMind pycolab: https://github.com/deepmind/pycolab/blob/master/pycolab/human_ui.py
+         """
+         color_fg = {
+             " ": (0, 0, 0),
+             "S": (368, 333, 388),
+             "H": (309, 572, 999),
+             "P": (999, 364, 0),
+             "F": (500, 999, 948),
+             "G": (999, 917, 298),
+             "?": (368, 333, 388),
+             "←": (309, 572, 999),
+             "↓": (999, 364, 0),
+             "→": (500, 999, 948),
+             "↑": (999, 917, 298),
+         }
+
+         color_pair = {}
+
+         cpair_0_fg_id, cpair_0_bg_id = curses.pair_content(0)
+         ids = set(range(curses.COLORS - 1)) - {
+             cpair_0_fg_id,
+             cpair_0_bg_id,
+         }
+
+         # Use color IDs from large to small...
+         ids = list(reversed(sorted(ids)))
+
+         # ...but only as many as we actually need.
+         ids = ids[: len(color_fg)]
+         color_ids = dict(zip(color_fg.values(), ids))
+
+         # Program these colors into curses.
+         for color, cid in color_ids.items():
+             curses.init_color(cid, *color)
+
+         # Now add the default colors to the color-to-ID map.
+         cpair_0_fg = curses.color_content(cpair_0_fg_id)
+         cpair_0_bg = curses.color_content(cpair_0_bg_id)
+         color_ids[cpair_0_fg] = cpair_0_fg_id
+         color_ids[cpair_0_bg] = cpair_0_bg_id
+
+         # The color pair IDs we'll use for all characters count up from 1;
+         # the "default" color pair 0 is already defined by curses.
+         color_pair.update(
+             {character: pid for pid, character in enumerate(color_fg, start=1)}
+         )
+
+         # Program these color pairs into curses.
+         for character, pid in color_pair.items():
+             # Get foreground and background colors for this character. In the
+             # absence of a specified background color, the foreground color is
+             # used for the background as well.
+             cpair_fg = color_fg.get(character, cpair_0_fg_id)
+             cpair_bg = color_fg.get(character, cpair_0_fg_id)
+
+             # Get color IDs for those colors and initialise a color pair.
+             cpair_fg_id = color_ids[cpair_fg]
+             cpair_bg_id = color_ids[cpair_bg]
+             curses.init_pair(pid, cpair_fg_id, cpair_bg_id)
+
+         return color_pair
+
+     def render_ncurses_ui(self, screen, board, color_pair, title, q_table):
+         screen.erase()
+
+         # Draw the title
+         screen.addstr(0, 2, title)
+
+         # Draw the game board
+         for row_index, board_line in enumerate(board, start=1):
+             screen.move(row_index, 2)
+             for codepoint in "".join(list(board_line)):
+                 screen.addch(codepoint, curses.color_pair(color_pair[codepoint]))
+
+         def action_to_char(action):
+             if action == 0:
+                 return "←"
+             elif action == 1:
+                 return "↓"
+             elif action == 2:
+                 return "→"
+             elif action == 3:
+                 return "↑"
+             else:
+                 return "?"
+
+         # Draw the greedy-action grid
+         max_action_table = np.argmax(q_table, axis=1).reshape(4, 4)
+         for row_index, row in enumerate(max_action_table, start=1):
+             screen.move(row_index, 8)
+             for action in row:
+                 char = action_to_char(action)
+                 screen.addch(char, curses.color_pair(color_pair[char]))
+
+         # Draw the Q-table: each display row holds 4 states x 4 actions
+         q_table_2d = q_table.reshape(4, 16)
+         for row_index, row in enumerate(q_table_2d, start=1):
+             screen.move(row_index, 14)
+             for col_index, col in enumerate(row):
+                 action = col_index % 4
+                 char = action_to_char(action)
+                 screen.addstr(f" {col:.2f}", curses.color_pair(color_pair[char]))
+                 if action == 3:
+                     screen.addstr(" ", curses.color_pair(color_pair[" "]))
+
+         # Redraw the game screen (in the curses memory buffer only).
+         screen.noutrefresh()
+
+     def ansi_frame_to_board(self, frame_string):
+         parts = frame_string.split("\n")
+         board = []
+         # ANSI escape code for the red background that marks the player cell
+         p = "\x1b[41m"
+         for part in parts[1:]:
+             if len(part):
+                 row = re.findall(r"S|F|H|G", part)
+                 try:
+                     row[part.index(p)] = "P"
+                 except ValueError:
+                     pass
+                 board.append(row)
+
+         return np.array(board)
+
+     def render(self, title=None, q_table=None):
+         if self.render_mode == "curses":
+             frame = self._render_text()
+
+             board = self.ansi_frame_to_board(frame)
+             self.render_ncurses_ui(
+                 self.curses_screen, board, self.curses_color_pairs, title, q_table
+             )
+
+             # Show the screen to the user.
+             curses.doupdate()
+             return board
+
+         return super().render()
+
+     def get_expected_new_state_for_action(self, action):
+         # P[s][a] is a list of (probability, next_state, reward, terminated)
+         # tuples; on a slippery lake the middle entry is the intended direction.
+         return self.P[self.s][action][1][1]
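
get_expected_new_state_for_action reads gym's transition table directly: on a slippery 4x4 lake, P[s][a] holds three (probability, next_state, reward, terminated) tuples, and the middle one is the transition for the direction that was actually requested. A minimal sketch of inspecting that table, assuming gym's built-in FrozenLake-v1:

import gym

env = gym.make("FrozenLake-v1", is_slippery=True)
env.reset()

# State 0, action 2 (right): each outcome has probability 1/3 on a slippery lake.
for prob, next_state, reward, terminated in env.unwrapped.P[0][2]:
    print(prob, next_state, reward, terminated)

This is also why a "slip" during evaluation is detected as any step whose observed new_state differs from that middle entry.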
main.py ADDED
@@ -0,0 +1,157 @@
+ import time
+ import curses
+ import numpy as np
+ import warnings
+ from environment import FrozenLakeEnvCustom
+ from agent import QLearningAgent
+
+ warnings.filterwarnings("ignore")
+
+
+ def train_agent(
+     n_training_episodes,
+     min_epsilon,
+     max_epsilon,
+     decay_rate,
+     env,
+     max_steps,
+     agent,
+     learning_rate,
+     gamma,
+     use_frame_delay,
+ ):
+     for episode in range(n_training_episodes + 1):
+         # Reduce epsilon (we need less and less exploration as training goes on)
+         epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(
+             -decay_rate * episode
+         )
+         state, info = env.reset()
+         done = False
+         for step in range(max_steps):
+             # Choose the action a_t using the epsilon-greedy policy
+             action = agent.epsilon_greedy_policy(state, epsilon)
+
+             # Take the action a and observe the outcome state s' and reward r
+             new_state, reward, done, truncated, info = env.step(action)
+             agent.update_q_table(state, action, reward, gamma, learning_rate, new_state)
+
+             env.render(
+                 title=f"Training: {episode}/{n_training_episodes}",
+                 q_table=agent.q_table,
+             )
+
+             if use_frame_delay:
+                 time.sleep(0.01)
+
+             if done:
+                 break
+
+             state = new_state
+     return agent
+
+
+ def evaluate_agent(env, max_steps, n_eval_episodes, agent, seed, use_frame_delay):
+     successful_episodes = []
+     episodes_slips = []
+     for episode in range(n_eval_episodes + 1):
+         if seed:
+             state, info = env.reset(seed=seed[episode])
+         else:
+             state, info = env.reset()
+         done = False
+         total_rewards_ep = 0
+
+         slips = []
+         for step in range(max_steps):
+             # Take the action (index) with the maximum expected future reward
+             # for this state
+             action = agent.greedy_policy(state)
+
+             expected_new_state = env.get_expected_new_state_for_action(action)
+             new_state, reward, done, truncated, info = env.step(action)
+             total_rewards_ep += reward
+
+             # A slip is any transition that differs from the intended one
+             if expected_new_state != new_state:
+                 slips.append((step, action, expected_new_state, new_state))
+
+             if reward != 0:
+                 successful_episodes.append(episode)
+
+             env.render(
+                 title=f"Evaluating: {episode}/{n_eval_episodes} | Slips: {len(slips)}",
+                 q_table=agent.q_table,
+             )
+
+             if use_frame_delay:
+                 time.sleep(0.01)
+
+             if done:
+                 break
+             state = new_state
+
+         # Record the slip count once per episode
+         episodes_slips.append(len(slips))
+
+     mean_slips = np.mean(episodes_slips)
+     return successful_episodes, mean_slips
+
+
+ def main(screen):
+     # Training parameters
+     n_training_episodes = 2000  # Total training episodes
+     learning_rate = 0.1  # Learning rate
+
+     # Evaluation parameters
+     n_eval_episodes = 100  # Total number of test episodes
+
+     # Environment parameters
+     max_steps = 99  # Max steps per episode
+     gamma = 0.99  # Discount rate
+     eval_seed = []  # The evaluation seeds for the environment
+
+     # Exploration parameters
+     max_epsilon = 1.0  # Exploration probability at start
+     min_epsilon = 0.05  # Minimum exploration probability
+     decay_rate = 0.0005  # Exponential decay rate for the exploration probability
+
+     use_frame_delay = False
+
+     env = FrozenLakeEnvCustom(map_name="4x4", is_slippery=True, render_mode="curses")
+
+     agent = QLearningAgent(env)
+     agent = train_agent(
+         n_training_episodes,
+         min_epsilon,
+         max_epsilon,
+         decay_rate,
+         env,
+         max_steps,
+         agent,
+         learning_rate,
+         gamma,
+         use_frame_delay,
+     )
+
+     successful_episodes, mean_slips = evaluate_agent(
+         env, max_steps, n_eval_episodes, agent, eval_seed, use_frame_delay
+     )
+
+     env_curses_screen = env.curses_screen
+     env_curses_screen.addstr(
+         5,
+         2,
+         f"Successful episodes: {len(successful_episodes)}/{n_eval_episodes} | Avg slips: {mean_slips:.2f}\n\n",
+     )
+     env_curses_screen.noutrefresh()
+     curses.doupdate()
+     time.sleep(10)
+
+
+ if __name__ == "__main__":
+     # curses.wrapper resets the terminal state after using curses.
+     # Call main("") instead if you want to leave the final state of the
+     # environment on the terminal.
+     curses.wrapper(main)
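
The exploration schedule in train_agent decays exponentially from max_epsilon toward min_epsilon. A quick standalone check of what the defaults above produce, using the same constants and nothing else:

import numpy as np

min_epsilon, max_epsilon, decay_rate = 0.05, 1.0, 0.0005
for episode in (0, 500, 1000, 2000):
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
    print(episode, round(epsilon, 2))  # 1.0, 0.79, 0.63, 0.4

So even at the end of the 2000 training episodes the agent still explores roughly 40% of the time.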
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ cloudpickle==2.2.0
+ gym==0.26.2
+ gym-notices==0.0.8
+ numpy==1.24.0
+ pygame==2.1.0
+ tqdm==4.64.1