Fix state handling
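Replace the module-level globals (`dataset_path`, `ipe`, `env_name`) and the stdout-redirecting `Logger`/`read_logs` machinery with a per-session dict held in `gr.State`, threaded through the event handlers as an explicit input and output so concurrent sessions no longer share mutable state. A training-steps input (`epochs`) is also exposed in the UI and passed through to `train_from_dataset`.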
app.py
CHANGED
@@ -15,7 +15,6 @@ import matplotlib.pyplot as plt
 import torch
 
 import gradio as gr
-import sys
 
 intro = """
 # Making RL Policy Interpretable with Kolmogorov-Arnold Network 🧠 ➙ 🔢
@@ -40,46 +39,18 @@ To follow the progress of KAN in RL you can check the repo [kanrl](https://githu
 envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v4", "Hopper-v4"]
 
 
-class Logger:
-    def __init__(self, filename):
-        self.terminal = sys.stdout
-        self.log = open(filename, "w")
-
-    def write(self, message):
-        self.terminal.write(message)
-        self.log.write(message)
-
-    def flush(self):
-        self.terminal.flush()
-        self.log.flush()
-
-    def isatty(self):
-        return False
-
-
-sys.stdout = Logger("output.log")
-sys.stderr = Logger("output.log")
-
-
-def read_logs():
-    sys.stdout.flush()
-    with open("output.log", "r") as f:
-        return f.read()
-
-
 if __name__ == "__main__":
     torch.set_default_dtype(torch.float32)
-    dataset_path = None
-    ipe = None
-    env_name = None
 
     def load_video_and_dataset(_env_name):
-        global dataset_path
-        global env_name
         env_name = _env_name
 
         dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
-        return video_path, gr.Button("Compute the symbolic policy!", interactive=True)
+        return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
+            "dataset_path": dataset_path,
+            "ipe": None,
+            "env_name": env_name
+        }
 
 
     def parse_integer_list(input_str):
@@ -94,45 +65,44 @@ if __name__ == "__main__":
         except ValueError:
             return False
 
-    def extract_interpretable_policy(
-        global ipe
-
+    def extract_interpretable_policy(kan_widths, epochs, state):
         widths = parse_integer_list(kan_widths)
         if widths is False:
             gr.Warning(f"Please enter the widths {kan_widths} in the right format... The current run will be executed with no hidden layers.")
             widths = None
 
-        ipe = InterpretablePolicyExtractor(env_name, widths)
-        ipe.train_from_dataset(dataset_path, steps=
+        state["ipe"] = InterpretablePolicyExtractor(state["env_name"], widths)
+        state["ipe"].train_from_dataset(state["dataset_path"], steps=epochs)
 
-        ipe.policy.prune()
-        ipe.policy.plot(mask=True, scale=5)
+        state["ipe"].policy.prune()
+        state["ipe"].policy.plot(mask=True, scale=5)
 
         fig = plt.gcf()
         fig.canvas.draw()
-
+        kan_architecture = np.array(fig.canvas.renderer.buffer_rgba())
+        plt.close()
+
+        return kan_architecture, state
 
-    def symbolic_policy():
-        global ipe
-        global env_name
+    def symbolic_policy(state):
         lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
-        ipe.policy.auto_symbolic(lib=lib)
-        env = gym.make(env_name, render_mode="rgb_array")
-        env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
+        state["ipe"].policy.auto_symbolic(lib=lib)
+        env = gym.make(state["env_name"], render_mode="rgb_array")
+        env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"""kan-{state["env_name"]}""")
 
-        rollouts(env, ipe.forward, 2)
+        rollouts(env, state["ipe"].forward, 2)
 
-        video_path = os.path.join("videos", f"kan-{env_name}.mp4")
-        video_files = glob.glob(os.path.join("videos", f"kan-{env_name}-episode*.mp4"))
+        video_path = os.path.join("videos", f"""kan-{state["env_name"]}.mp4""")
+        video_files = glob.glob(os.path.join("videos", f"""kan-{state["env_name"]}-episode*.mp4"""))
         clips = [VideoFileClip(file) for file in video_files]
         final_clip = concatenate_videoclips(clips)
         final_clip.write_videofile(video_path, codec="libx264", fps=24)
 
         symbolic_formula = f"### The symbolic formula of the policy is:"
-        formulas = ipe.policy.symbolic_formula()[0]
+        formulas = state["ipe"].policy.symbolic_formula()[0]
         for i, formula in enumerate(formulas):
             symbolic_formula += "\n$$ a_" + str(i) + "=" + latex(formula) + "$$"
-        if ipe._action_is_discrete:
+        if state["ipe"]._action_is_discrete:
             symbolic_formula += "\n" + r"$$ a = \underset{i}{\mathrm{argmax}} \ a_i.$$"
 
         return video_path, symbolic_formula
@@ -143,6 +113,11 @@ if __name__ == "__main__":
     """
 
     with gr.Blocks(theme='gradio/monochrome', css=css) as app:
+        state = gr.State({
+            "dataset_path": None,
+            "ipe": None,
+            "env_name": None
+        })
         gr.Markdown(intro)
 
         with gr.Row():
@@ -151,18 +126,16 @@ if __name__ == "__main__":
                 choice = gr.Dropdown(envs, label="Environment name")
                 expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
                 kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
+                epochs = gr.Number(value=20, label="KAN training steps", minimum=1, maximum=100)
                 button = gr.Button("Compute the symbolic policy!", interactive=False)
             with gr.Column():
                 gr.Markdown("### Symbolic policy extraction")
                 kan_architecture = gr.Image(interactive=False, label="KAN architecture")
                 sym_video = gr.Video(label="Symbolic policy video", interactive=False, autoplay=True)
                 sym_formula = gr.Markdown(elem_id="formula")
-
-
-
-        button.click(extract_interpretable_policy, inputs=[choice, kan_widths], outputs=[kan_architecture]).then(
-            symbolic_policy, inputs=[], outputs=[sym_video, sym_formula]
+        choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button, state])
+        button.click(extract_interpretable_policy, inputs=[kan_widths, epochs, state], outputs=[kan_architecture, state]).then(
+            symbolic_policy, inputs=[state], outputs=[sym_video, sym_formula]
         )
-        app.load(read_logs, None, logs, every=1)
 
     app.launch()
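For reference, the pattern adopted above is Gradio's session-state mechanism: a `gr.State` component holds one value per browser session, and handlers read and update it by listing it in `inputs` and `outputs`, exactly as the new `choice.input` and `button.click` wiring does. A minimal self-contained sketch of the same pattern (a hypothetical counter demo, not part of app.py):

import gradio as gr

with gr.Blocks() as demo:
    # gr.State gives every connected session its own copy of this dict;
    # module-level globals, by contrast, are shared across all users.
    state = gr.State({"count": 0})
    out = gr.Textbox(label="Result")
    button = gr.Button("Increment")

    def increment(s):
        # Mutate the session's copy and return it so the stored value
        # is updated for the next event.
        s["count"] += 1
        return f'count = {s["count"]}', s

    # The state component appears in both inputs and outputs,
    # mirroring the wiring in app.py.
    button.click(increment, inputs=[state], outputs=[out, state])

demo.launch()

Returning the dict while listing `state` in `outputs` follows Gradio's documented update pattern; it is also why `extract_interpretable_policy` returns `state` alongside the rendered KAN architecture.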