jackvial committed
Commit • 0484cb0
Parent(s):
setup
Files changed:
- .gitignore +3 -0
- .vscode/launch.json +15 -0
- agent.py +30 -0
- environment.py +179 -0
- main.py +157 -0
- requirements.txt +6 -0
.gitignore
ADDED
@@ -0,0 +1,3 @@
+.env
+frozen_lake_env
+__pycache__
.vscode/launch.json
ADDED
@@ -0,0 +1,15 @@
+{
+    "version": "0.2.0",
+    "configurations": [
+        {
+            "name": "Debug",
+            "type": "python",
+            "request": "launch",
+            "python": "${workspaceFolder}/frozen_lake_env/bin/python3.10",
+            "program": "${workspaceFolder}/main.py",
+            "envFile": "${workspaceFolder}/.env",
+            "console": "integratedTerminal",
+            "justMyCode": false
+        }
+    ]
+}
agent.py
ADDED
@@ -0,0 +1,30 @@
+import numpy as np
+
+
+class QLearningAgent:
+    def __init__(self, env) -> None:
+        self.env = env
+        self.q_table = self.build_q_table(env.observation_space.n, env.action_space.n)
+
+    def build_q_table(self, n_states, n_actions):
+        return np.zeros((n_states, n_actions))
+
+    def epsilon_greedy_policy(self, state, epsilon):
+
+        # With probability epsilon take a random action, otherwise take the
+        # action that has the highest Q-value for the current state
+        if np.random.random() < epsilon:
+            return np.random.choice(self.env.action_space.n)
+        return np.argmax(self.q_table[state])
+
+    def greedy_policy(self, state):
+        return np.argmax(self.q_table[state])
+
+    def update_q_table(self, state, action, reward, gamma, learning_rate, new_state):
+
+        # Update Q(s,a) := Q(s,a) + lr * [R(s,a) + gamma * max Q(s',a') - Q(s,a)]
+        current_q = self.q_table[state][action]
+        next_max_q = np.max(self.q_table[new_state])
+        self.q_table[state][action] = current_q + learning_rate * (
+            reward + gamma * next_max_q - current_q
+        )
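
Since QLearningAgent is a plain tabular learner, it can be exercised on its own outside this Space. A minimal usage sketch (not part of the commit; the FrozenLake-v1 env id and the single hand-rolled step are assumptions for illustration):

import gym
from agent import QLearningAgent

# Hypothetical standalone usage, assuming gym 0.26-style reset/step returns.
env = gym.make("FrozenLake-v1", map_name="4x4", is_slippery=True)
agent = QLearningAgent(env)  # q_table starts as 16 x 4 zeros

state, info = env.reset(seed=0)
action = agent.epsilon_greedy_policy(state, epsilon=0.5)
new_state, reward, done, truncated, info = env.step(action)

# One Bellman backup: moves Q(s, a) toward r + gamma * max_a' Q(s', a')
agent.update_q_table(state, action, reward, 0.99, 0.1, new_state)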
environment.py
ADDED
@@ -0,0 +1,179 @@
+import re
+import curses
+import numpy as np
+import collections
+import warnings
+from typing import Optional
+from gym.envs.toy_text.frozen_lake import FrozenLakeEnv
+
+warnings.filterwarnings("ignore")
+
+
+class FrozenLakeEnvCustom(FrozenLakeEnv):
+    def __init__(
+        self,
+        render_mode: Optional[str] = None,
+        desc=None,
+        map_name="4x4",
+        is_slippery=True,
+    ):
+        self.curses_screen = curses.initscr()
+        curses.start_color()
+        curses.curs_set(0)
+        self.curses_color_pairs = self.build_ncurses_color_pairs()
+
+        # Blocking reads
+        self.curses_screen.timeout(-1)
+
+        super().__init__(
+            render_mode=render_mode,
+            desc=desc,
+            map_name=map_name,
+            is_slippery=is_slippery,
+        )
+
+    def build_ncurses_color_pairs(self):
+        """
+        Based on DeepMind pycolab: https://github.com/deepmind/pycolab/blob/master/pycolab/human_ui.py
+        """
+
+        color_fg = {
+            " ": (0, 0, 0),
+            "S": (368, 333, 388),
+            "H": (309, 572, 999),
+            "P": (999, 364, 0),
+            "F": (500, 999, 948),
+            "G": (999, 917, 298),
+            "?": (368, 333, 388),
+            "←": (309, 572, 999),
+            "↓": (999, 364, 0),
+            "→": (500, 999, 948),
+            "↑": (999, 917, 298),
+        }
+
+        color_pair = {}
+
+        cpair_0_fg_id, cpair_0_bg_id = curses.pair_content(0)
+        ids = set(range(curses.COLORS - 1)) - {
+            cpair_0_fg_id,
+            cpair_0_bg_id,
+        }
+
+        # We use color IDs from large to small.
+        ids = list(reversed(sorted(ids)))
+
+        # But only those color IDs we actually need.
+        ids = ids[: len(color_fg)]
+        color_ids = dict(zip(color_fg.values(), ids))
+
+        # Program these colors into curses.
+        for color, cid in color_ids.items():
+            curses.init_color(cid, *color)
+
+        # Now add the default colors to the color-to-ID map.
+        cpair_0_fg = curses.color_content(cpair_0_fg_id)
+        cpair_0_bg = curses.color_content(cpair_0_bg_id)
+        color_ids[cpair_0_fg] = cpair_0_fg_id
+        color_ids[cpair_0_bg] = cpair_0_bg_id
+
+        # The color pair IDs we'll use for all characters count up from 1; note
+        # that the "default" color pair of 0 is already defined, so enumeration
+        # starts at 1.
+        color_pair.update(
+            {character: pid for pid, character in enumerate(color_fg, start=1)}
+        )
+
+        # Program these color pairs into curses, and that's all there is to do.
+        for character, pid in color_pair.items():
+
+            # Get foreground and background colors for this character. Note how in
+            # the absence of a specified background color, the same color as the
+            # foreground is used.
+            cpair_fg = color_fg.get(character, cpair_0_fg_id)
+            cpair_bg = color_fg.get(character, cpair_0_fg_id)
+
+            # Get color IDs for those colors and initialise a color pair.
+            cpair_fg_id = color_ids[cpair_fg]
+            cpair_bg_id = color_ids[cpair_bg]
+            curses.init_pair(pid, cpair_fg_id, cpair_bg_id)
+
+        return color_pair
+
+    def render_ncurses_ui(self, screen, board, color_pair, title, q_table):
+        screen.erase()
+
+        # Draw the title
+        screen.addstr(0, 2, title)
+
+        # Draw the game board
+        for row_index, board_line in enumerate(board, start=1):
+            screen.move(row_index, 2)
+            for codepoint in "".join(list(board_line)):
+                screen.addch(codepoint, curses.color_pair(color_pair[codepoint]))
+
+        def action_to_char(action):
+            if action == 0:
+                return "←"
+            elif action == 1:
+                return "↓"
+            elif action == 2:
+                return "→"
+            elif action == 3:
+                return "↑"
+            else:
+                return "?"
+
+        # Draw the action grid
+        max_action_table = np.argmax(q_table, axis=1).reshape(4, 4)
+        for row_index, row in enumerate(max_action_table, start=1):
+            screen.move(row_index, 8)
+            for action in row:
+                char = action_to_char(action)
+                screen.addch(char, curses.color_pair(color_pair[char]))
+
+        # Draw the Q-table
+        q_table_2d = q_table.reshape(4, 16)
+        for row_index, row in enumerate(q_table_2d, start=1):
+            screen.move(row_index, 14)
+            for col_index, col in enumerate(row):
+                action = col_index % 4
+                char = action_to_char(action)
+                screen.addstr(f" {col:.2f}", curses.color_pair(color_pair[char]))
+                if action == 3:
+                    screen.addstr(" ", curses.color_pair(color_pair[" "]))
+
+        # Redraw the game screen (but in the curses memory buffer only).
+        screen.noutrefresh()
+
+    def ansi_frame_to_board(self, frame_string):
+        parts = frame_string.split("\n")
+        board = []
+        p = "\x1b[41m"
+        for part in parts[1:]:
+            if len(part):
+                row = re.findall(r"S|F|H|G", part)
+                try:
+                    row[part.index(p)] = "P"
+                except ValueError:
+                    pass
+                board.append(row)
+
+        return np.array(board)
+
+    def render(self, title=None, q_table=None):
+        if self.render_mode == "curses":
+            frame = self._render_text()
+
+            board = self.ansi_frame_to_board(frame)
+            self.render_ncurses_ui(
+                self.curses_screen, board, self.curses_color_pairs, title, q_table
+            )
+
+            # Show the screen to the user.
+            curses.doupdate()
+            return board
+
+        return super().render()
+
+    def get_expected_new_state_for_action(self, action):
+        return self.P[self.s][action][1][1]
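
To make ansi_frame_to_board concrete: the \x1b[41m escape that FrozenLake's text renderer emits sits immediately before the player's cell, and since only plain tile letters precede it within a row, its character offset doubles as the player's column index. A standalone sketch (not part of the commit; the sample frame string is a hand-written assumption about the renderer's output):

import re

# Hypothetical frame mimicking FrozenLake's text renderer: the first line
# echoes the last action, and "\x1b[41m" highlights the player's cell.
frame = "  (Left)\nS\x1b[41mF\x1b[0mFF\nFHFH\nFFFH\nHFFG\n"

p = "\x1b[41m"
board = []
for part in frame.split("\n")[1:]:
    if len(part):
        row = re.findall(r"S|F|H|G", part)  # the four tile letters
        try:
            row[part.index(p)] = "P"  # escape's offset == player's column
        except ValueError:
            pass  # no player marker in this row
        board.append(row)

print(board[0])  # ['S', 'P', 'F', 'F']

Similarly, get_expected_new_state_for_action leans on how gym 0.26 builds the slippery transition table: P[s][a] lists three (prob, next_state, reward, terminated) tuples in the order [(a - 1) % 4, a, (a + 1) % 4], so index [1] is the intended (non-slip) move and [1][1] its resulting state.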
main.py
ADDED
@@ -0,0 +1,157 @@
+import time
+import curses
+import numpy as np
+import warnings
+from environment import FrozenLakeEnvCustom
+from agent import QLearningAgent
+
+warnings.filterwarnings("ignore")
+
+
+def train_agent(
+    n_training_episodes,
+    min_epsilon,
+    max_epsilon,
+    decay_rate,
+    env,
+    max_steps,
+    agent,
+    learning_rate,
+    gamma,
+    use_frame_delay,
+):
+    for episode in range(n_training_episodes + 1):
+
+        # Reduce epsilon (because we need less and less exploration)
+        epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(
+            -decay_rate * episode
+        )
+        state, info = env.reset()
+        done = False
+        for step in range(max_steps):
+
+            # Choose the action At using the epsilon-greedy policy
+            action = agent.epsilon_greedy_policy(state, epsilon)
+
+            # Take action At and observe Rt+1 and St+1,
+            # i.e. the reward (r) and the outcome state (s')
+            new_state, reward, done, truncated, info = env.step(action)
+            agent.update_q_table(state, action, reward, gamma, learning_rate, new_state)
+
+            env.render(
+                title=f"Training: {episode}/{n_training_episodes}",
+                q_table=agent.q_table,
+            )
+
+            if use_frame_delay:
+                time.sleep(0.01)
+
+            if done:
+                break
+
+            state = new_state
+    return agent
+
+
+def evaluate_agent(env, max_steps, n_eval_episodes, agent, seed, use_frame_delay):
+    successful_episodes = []
+    episodes_slips = []
+    for episode in range(n_eval_episodes + 1):
+        if seed:
+            state, info = env.reset(seed=seed[episode])
+        else:
+            state, info = env.reset()
+        done = False
+        total_rewards_ep = 0
+
+        slips = []
+        for step in range(max_steps):
+
+            # Take the action (index) that has the maximum expected future reward given that state
+            action = agent.greedy_policy(state)
+
+            expected_new_state = env.get_expected_new_state_for_action(action)
+            new_state, reward, done, truncated, info = env.step(action)
+            total_rewards_ep += reward
+
+            if expected_new_state != new_state:
+                slips.append((step, action, expected_new_state, new_state))
+
+            if reward != 0:
+                successful_episodes.append(episode)
+
+            env.render(
+                title=f"Evaluating: {episode}/{n_eval_episodes} | Slips: {len(slips)}",
+                q_table=agent.q_table,
+            )
+            episodes_slips.append(len(slips))
+
+            if use_frame_delay:
+                time.sleep(0.01)
+
+            if done:
+                break
+            state = new_state
+
+    mean_slips = np.mean(episodes_slips)
+    return successful_episodes, mean_slips
+
+
+def main(screen):
+
+    # Training parameters
+    n_training_episodes = 2000  # Total training episodes
+    learning_rate = 0.1  # Learning rate
+
+    # Evaluation parameters
+    n_eval_episodes = 100  # Total number of test episodes
+
+    # Environment parameters
+    max_steps = 99  # Max steps per episode
+    gamma = 0.99  # Discounting rate
+    eval_seed = []  # The evaluation seeds of the environment
+
+    # Exploration parameters
+    max_epsilon = 1.0  # Exploration probability at start
+    min_epsilon = 0.05  # Minimum exploration probability
+    decay_rate = 0.0005  # Exponential decay rate for exploration prob
+
+    use_frame_delay = False
+
+    env = FrozenLakeEnvCustom(map_name="4x4", is_slippery=True, render_mode="curses")
+
+    agent = QLearningAgent(env)
+    agent = train_agent(
+        n_training_episodes,
+        min_epsilon,
+        max_epsilon,
+        decay_rate,
+        env,
+        max_steps,
+        agent,
+        learning_rate,
+        gamma,
+        use_frame_delay,
+    )
+
+    successful_episodes, mean_slips = evaluate_agent(
+        env, max_steps, n_eval_episodes, agent, eval_seed, use_frame_delay
+    )
+
+    env_curses_screen = env.curses_screen
+    env_curses_screen.addstr(
+        5,
+        2,
+        f"Successful episodes: {len(successful_episodes)}/{n_eval_episodes} | Avg slips: {mean_slips:.2f}\n\n",
+    )
+    env_curses_screen.noutrefresh()
+    curses.doupdate()
+    time.sleep(10)
+
+
+if __name__ == "__main__":
+
+    # curses.wrapper resets the terminal state after using curses.
+    # Call main("") instead if you want to leave the final state of the environment
+    # on the terminal.
+    curses.wrapper(main)
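
One parameter interaction worth checking: with decay_rate = 0.0005, epsilon never gets close to min_epsilon within the 2000 training episodes. A quick tabulation (not part of the commit) of the same schedule used in train_agent:

import numpy as np

min_epsilon, max_epsilon, decay_rate = 0.05, 1.0, 0.0005
for episode in (0, 500, 1000, 2000):
    epsilon = min_epsilon + (max_epsilon - min_epsilon) * np.exp(-decay_rate * episode)
    print(f"episode {episode}: epsilon = {epsilon:.3f}")
# episode 0: epsilon = 1.000
# episode 500: epsilon = 0.790
# episode 1000: epsilon = 0.626
# episode 2000: epsilon = 0.399

So roughly 40% of actions in the final episodes are still random; a larger decay_rate (or more episodes) would be needed to actually approach min_epsilon.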
requirements.txt
ADDED
@@ -0,0 +1,6 @@
+cloudpickle==2.2.0
+gym==0.26.2
+gym-notices==0.0.8
+numpy==1.24.0
+pygame==2.1.0
+tqdm==4.64.1