riiswa committed
Commit ca85408
1 Parent(s): b118f86

First commit

Files changed (6)
  1. .gitignore +8 -0
  2. app.py +168 -0
  3. example.py +26 -0
  4. interpretable.py +47 -0
  5. requirements.txt +17 -0
  6. utils.py +178 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
+ datasets
+ videos
+ figures
+ *.log
+ *.png
+ mujoco210
+ .DS_Store
+ __pycache__/
app.py ADDED
@@ -0,0 +1,168 @@
+ import glob
+ import os
+
+ import gymnasium as gym
+ import numpy as np
+ from gymnasium.wrappers import RecordVideo
+ from moviepy.video.compositing.concatenate import concatenate_videoclips
+ from moviepy.video.io.VideoFileClip import VideoFileClip
+ from sympy import latex
+
+ from interpretable import InterpretablePolicyExtractor
+ from utils import generate_dataset_from_expert, rollouts
+ import matplotlib.pyplot as plt
+
+ import torch
+
+ import gradio as gr
+ import sys
+
+ intro = """
+ # Making RL Policy Interpretable with Kolmogorov-Arnold Network 🧠 ➙ 🔢
+
+ Waris Radji<sup>1</sup>, Corentin Léger<sup>2</sup>, Hector Kohler<sup>1</sup>
+ <small><sup>1</sup>[Inria, team Scool](https://team.inria.fr/scool/) <sup>2</sup>[Inria, team Flowers](https://flowers.inria.fr/)</small>
+
+
+
+ In this demo, we showcase a method to make a trained Reinforcement Learning (RL) policy interpretable using a Kolmogorov-Arnold Network (KAN). The process transfers the knowledge of a pre-trained RL policy to a KAN by training the KAN to map observations from the pre-trained policy's trajectories to the corresponding actions.
+
+ ## Procedure
+
+ - Train the KAN on observations from trajectories generated by a pre-trained RL policy, so that it learns to map observations to the corresponding actions.
+ - Apply symbolic regression algorithms to the KAN's learned mapping.
+ - Extract an interpretable policy expressed in symbolic form.
+
+ For more information about KAN, read the [paper](https://arxiv.org/abs/2404.19756) and check the [official PyTorch implementation](https://github.com/KindXiaoming/pykan).
+ To follow the progress of KAN in RL, check the [kanrl](https://github.com/riiswa/kanrl) repo.
+ """
+
+ envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v4", "Hopper-v4"]
+
+
+ class Logger:
+     def __init__(self, filename):
+         self.terminal = sys.stdout
+         self.log = open(filename, "w")
+
+     def write(self, message):
+         self.terminal.write(message)
+         self.log.write(message)
+
+     def flush(self):
+         self.terminal.flush()
+         self.log.flush()
+
+     def isatty(self):
+         return False
+
+
+ sys.stdout = Logger("output.log")
+ sys.stderr = Logger("output.log")
+
+
+ def read_logs():
+     sys.stdout.flush()
+     with open("output.log", "r") as f:
+         return f.read()
+
+
+ if __name__ == "__main__":
+     torch.set_default_dtype(torch.float32)
+     dataset_path = None
+     ipe = None
+     env_name = None
+
+     def load_video_and_dataset(_env_name):
+         global dataset_path
+         global env_name
+         env_name = _env_name
+
+         dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
+         return video_path, gr.Button("Compute the symbolic policy!", interactive=True)
+
+
+     def parse_integer_list(input_str):
+         if not input_str or input_str.isspace():
+             return None
+
+         elements = input_str.split(',')
+
+         try:
+             int_list = tuple([int(elem.strip()) for elem in elements])
+             return int_list
+         except ValueError:
+             return False
+
+     def extract_interpretable_policy(env_name, kan_widths):
+         global ipe
+
+         widths = parse_integer_list(kan_widths)
+         if widths is False:
+             gr.Warning(f"Please enter the widths ({kan_widths}) in the right format... The current run will use no hidden layers.")
+             widths = None
+
+         ipe = InterpretablePolicyExtractor(env_name, widths)
+         ipe.train_from_dataset(dataset_path, steps=50)
+
+         ipe.policy.prune()
+         ipe.policy.plot(mask=True, scale=5)
+
+         fig = plt.gcf()
+         fig.canvas.draw()
+         return np.array(fig.canvas.renderer.buffer_rgba())
+
+     def symbolic_policy():
+         global ipe
+         global env_name
+         lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
+         ipe.policy.auto_symbolic(lib=lib)
+         env = gym.make(env_name, render_mode="rgb_array")
+         env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
+
+         rollouts(env, ipe.forward, 2)
+
+         video_path = os.path.join("videos", f"kan-{env_name}.mp4")
+         video_files = glob.glob(os.path.join("videos", f"kan-{env_name}-episode*.mp4"))
+         clips = [VideoFileClip(file) for file in video_files]
+         final_clip = concatenate_videoclips(clips)
+         final_clip.write_videofile(video_path, codec="libx264", fps=24)
+
+         symbolic_formula = "### The symbolic formula of the policy is:"
+         formulas = ipe.policy.symbolic_formula()[0]
+         for i, formula in enumerate(formulas):
+             symbolic_formula += "\n$$ a_" + str(i) + "=" + latex(formula) + "$$"
+         if ipe._action_is_discrete:
+             symbolic_formula += "\n" + r"$$ a = \underset{i}{\mathrm{argmax}} \ a_i.$$"
+
+         return video_path, symbolic_formula
+
+
+     css = """
+     #formula {overflow-x: auto!important};
+     """
+
+     with gr.Blocks(theme='gradio/monochrome', css=css) as app:
+         gr.Markdown(intro)
+
+         with gr.Row():
+             with gr.Column():
+                 gr.Markdown("### Pretrained policy loading (PPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
+                 choice = gr.Dropdown(envs, label="Environment name")
+                 expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
+                 kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
+                 button = gr.Button("Compute the symbolic policy!", interactive=False)
+             with gr.Column():
+                 gr.Markdown("### Symbolic policy extraction")
+                 kan_architecture = gr.Image(interactive=False, label="KAN architecture")
+                 sym_video = gr.Video(label="Symbolic policy video", interactive=False, autoplay=True)
+                 sym_formula = gr.Markdown(elem_id="formula")
+         with gr.Accordion("See logs"):
+             logs = gr.Textbox(label="Logs", interactive=False)
+         choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button])
+         button.click(extract_interpretable_policy, inputs=[choice, kan_widths], outputs=[kan_architecture]).then(
+             symbolic_policy, inputs=[], outputs=[sym_video, sym_formula]
+         )
+         app.load(read_logs, None, logs, every=1)
+
+     app.launch()
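
The `intro` above describes the procedure in prose; as a quick reference, here is a minimal sketch of the same pipeline without the Gradio UI, reusing only the pykan calls that already appear in `interpretable.py` and `utils.py`. The CartPole-v1 dimensions, the hidden width, and the reduced symbolic library are illustrative assumptions, not part of the commit.

```python
import torch
from kan import KAN

from utils import generate_dataset_from_expert

# Roll out the pretrained PPO expert and save (observation, action) pairs.
dataset_path, _ = generate_dataset_from_expert("ppo", "CartPole-v1")
dataset = torch.load(dataset_path)

# 1. Distill the expert: fit a small KAN mapping observations to action logits.
#    CartPole-v1 has 4 observations and 2 discrete actions; hidden width 2 is the app's default.
policy = KAN(width=[4, 2, 2])
policy.train(dataset, opt="LBFGS", steps=50, loss_fn=torch.nn.CrossEntropyLoss())

# 2. Symbolic regression on the learned mapping.
policy.prune()
policy.auto_symbolic(lib=['x', 'x^2', 'exp', 'tanh', 'sin', 'abs'])

# 3. Read the interpretable policy as closed-form expressions (one per action logit).
print(policy.symbolic_formula()[0])
```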
example.py ADDED
@@ -0,0 +1,26 @@
+ import gym
+ from gym.wrappers import RecordVideo
+ from matplotlib import pyplot as plt
+
+ from interpretable.interpretable import InterpretablePolicyExtractor
+ from interpretable.utils import generate_dataset_from_expert, rollouts
+
+ if __name__ == "__main__":
+     env_name = "CartPole-v1"
+     dataset_path = generate_dataset_from_expert("ppo", env_name, force=True)
+     ipe = InterpretablePolicyExtractor(env_name)
+     results = ipe.train_from_dataset(dataset_path)
+     ipe.policy.prune()
+     ipe.policy.plot(mask=True)
+     plt.savefig("kan-policy.png")
+
+     env = gym.make(env_name, render_mode="rgb_array")
+     env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
+
+     ipe.policy.auto_symbolic()
+
+     ipe.policy.plot(mask=True)
+     plt.savefig("sym-policy.png")
+     print(ipe.policy.symbolic_formula())
+
+     rollouts(env, ipe.forward, 2)
interpretable.py ADDED
@@ -0,0 +1,47 @@
+ import torch
+ from typing import Dict, Tuple, Optional, Callable, Union
+ import gymnasium as gym
+ from kan import KAN
+ import numpy as np
+
+
+ def extract_dim(space: gym.Space):
+     if isinstance(space, gym.spaces.Box) and len(space.shape) == 1:
+         return space.shape[0], False
+     elif isinstance(space, gym.spaces.Discrete):
+         return space.n, True
+     else:
+         raise NotImplementedError(f"There is no support for space {space}.")
+
+
+ class InterpretablePolicyExtractor:
+     lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
+
+     def __init__(self, env_name: str, hidden_widths: Optional[Tuple[int, ...]] = None):
+         self.env = gym.make(env_name)
+         if hidden_widths is None:
+             hidden_widths = []
+         observation_dim, self._observation_is_discrete = extract_dim(self.env.observation_space)
+         action_dim, self._action_is_discrete = extract_dim(self.env.action_space)
+         self.policy = KAN(width=[observation_dim, *hidden_widths, action_dim])
+         self.loss_fn = torch.nn.MSELoss() if not self._action_is_discrete else torch.nn.CrossEntropyLoss()
+
+     def train_from_dataset(self, dataset: Union[Dict[str, torch.Tensor], str], steps: int = 20):
+         if isinstance(dataset, str):
+             dataset = torch.load(dataset)
+         if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
+             dataset["train_label"] = dataset["train_label"][:, None]
+         if dataset["test_label"].ndim == 1 and not self._action_is_discrete:
+             dataset["test_label"] = dataset["test_label"][:, None]
+         return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
+
+     def forward(self, observation):
+         observation = torch.from_numpy(observation)
+         action = self.policy(observation.unsqueeze(0))
+         if self._action_is_discrete:
+             return action.argmax(axis=-1).squeeze().item()
+         else:
+             return action.squeeze(0).detach().numpy()
+
+     def train_from_policy(self, policy: Callable[[np.ndarray], Union[np.ndarray, int, float]], steps: int):
+         raise NotImplementedError()  # TODO
requirements.txt ADDED
@@ -0,0 +1,17 @@
+ gymnasium[box2d,mujoco]~=0.29.1
+ torch~=2.2.2
+ numpy~=1.26.4
+ pykan~=0.0.2
+ tqdm~=4.66.2
+ scikit-learn~=1.4.2
+ matplotlib~=3.8.4
+ moviepy~=1.0.3
+ huggingface_hub
+ gradio
+ huggingface_sb3
+ stable_baselines3
+ rl_zoo3
+ gym
+ shimmy>=0.2.1
+ mujoco-py
+ cython<3
utils.py ADDED
@@ -0,0 +1,178 @@
+ import glob
+ import os
+ import pickle
+
+ import torch
+ import numpy as np
+ import gymnasium as gym
+ from huggingface_hub.utils import EntryNotFoundError
+ from huggingface_sb3 import load_from_hub
+ from moviepy.video.compositing.concatenate import concatenate_videoclips
+ from moviepy.video.io.VideoFileClip import VideoFileClip
+ from rl_zoo3 import ALGOS
+ from gymnasium.wrappers import RecordVideo
+ from stable_baselines3.common.running_mean_std import RunningMeanStd
+
+ import os
+ import tarfile
+ import urllib.request
+
+
+ def install_mujoco():
+     mujoco_url = "https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz"
+     mujoco_file = "mujoco210-linux-x86_64.tar.gz"
+     mujoco_dir = "mujoco210"
+
+     # Check if the directory already exists
+     if not os.path.exists("mujoco210"):
+         # Download Mujoco if not exists
+         print("Downloading Mujoco...")
+         urllib.request.urlretrieve(mujoco_url, mujoco_file)
+
+         # Extract Mujoco
+         print("Extracting Mujoco...")
+         with tarfile.open(mujoco_file, "r:gz") as tar:
+             tar.extractall()
+
+         # Clean up the downloaded tar file
+         os.remove(mujoco_file)
+
+         print("Mujoco installed successfully!")
+     else:
+         print("Mujoco already installed.")
+
+     # Set environment variable MUJOCO_PY_MUJOCO_PATH
+     os.environ["MUJOCO_PY_MUJOCO_PATH"] = os.path.abspath(mujoco_dir)
+
+     ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
+     mujoco_bin_path = os.path.join(os.path.abspath(mujoco_dir), "bin")
+     if mujoco_bin_path not in ld_library_path:
+         os.environ["LD_LIBRARY_PATH"] = ld_library_path + ":" + mujoco_bin_path
+
+
+
+ class NormalizeObservation(gym.Wrapper):
+     def __init__(self, env: gym.Env, clip_obs: float, obs_rms: RunningMeanStd, epsilon: float):
+         gym.Wrapper.__init__(self, env)
+         self.clip_obs = clip_obs
+         self.obs_rms = obs_rms
+         self.epsilon = epsilon
+
+     def step(self, action):
+         observation, reward, terminated, truncated, info = self.env.step(action)
+         observation = self.normalize(np.array([observation]))[0]
+         return observation, reward, terminated, truncated, info
+
+     def reset(self, **kwargs):
+         observation, info = self.env.reset(**kwargs)
+         return self.normalize(np.array([observation]))[0], info
+
+     def normalize(self, obs):
+         return np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
+
+
+ class CreateDataset(gym.Wrapper):
+     def __init__(self, env: gym.Env):
+         gym.Wrapper.__init__(self, env)
+         self.observations = []
+         self.actions = []
+         self.last_observation = None
+
+     def step(self, action):
+         self.observations.append(self.last_observation)
+         self.actions.append(action)
+         observation, reward, terminated, truncated, info = self.env.step(action)
+         self.last_observation = observation
+         return observation, reward, terminated, truncated, info
+
+     def reset(self, **kwargs):
+         observation, info = self.env.reset(**kwargs)
+         self.last_observation = observation
+         return observation, info
+
+     def get_dataset(self):
+         if isinstance(self.env.action_space, gym.spaces.Box) and self.env.action_space.shape != (1,):
+             actions = np.vstack(self.actions)
+         else:
+             actions = np.hstack(self.actions)
+         return np.vstack(self.observations), actions
+
+
+ def rollouts(env, policy, num_episodes=1):
+     for episode in range(num_episodes):
+         done = False
+         observation, _ = env.reset()
+         while not done:
+             action = policy(observation)
+             observation, reward, terminated, truncated, _ = env.step(action)
+             done = terminated or truncated
+     env.close()
+
+
+ def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
+     if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
+         install_mujoco()
+         if env_name == "Swimmer-v4":
+             env_name = "Swimmer-v3"
+         elif env_name == "Hopper-v4":
+             env_name = "Hopper-v3"
+     dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
+     video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
+     if os.path.exists(dataset_path) and os.path.exists(video_path) and not force:
+         return dataset_path, video_path
+     repo_id = f"sb3/{algo}-{env_name}"
+     policy_file = f"{algo}-{env_name}.zip"
+
+     expert_path = load_from_hub(repo_id, policy_file)
+     try:
+         vec_normalize_path = load_from_hub(repo_id, "vec_normalize.pkl")
+         with open(vec_normalize_path, "rb") as f:
+             vec_normalize = pickle.load(f)
+         if vec_normalize.norm_obs:
+             vec_normalize_params = {"clip_obs": vec_normalize.clip_obs, "obs_rms": vec_normalize.obs_rms, "epsilon": vec_normalize.epsilon}
+         else:
+             vec_normalize_params = None
+     except EntryNotFoundError:
+         vec_normalize_params = None
+
+     expert = ALGOS[algo].load(expert_path)
+     train_env = gym.make(env_name)
+     train_env = CreateDataset(train_env)
+     if vec_normalize_params is not None:
+         train_env = NormalizeObservation(train_env, **vec_normalize_params)
+     test_env = gym.make(env_name, render_mode="rgb_array")
+     test_env = CreateDataset(test_env)
+     if vec_normalize_params is not None:
+         test_env = NormalizeObservation(test_env, **vec_normalize_params)
+     test_env = RecordVideo(test_env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"{algo}-{env_name}")
+
+     def policy(obs):
+         return expert.predict(obs, deterministic=True)[0]
+
+     os.makedirs("videos", exist_ok=True)
+     rollouts(train_env, policy, num_train_episodes)
+     rollouts(test_env, policy, num_test_episodes)
+
+     train_observations, train_actions = train_env.get_dataset()
+     test_observations, test_actions = test_env.get_dataset()
+
+     dataset = {
+         "train_input": torch.from_numpy(train_observations),
+         "test_input": torch.from_numpy(test_observations),
+         "train_label": torch.from_numpy(train_actions),
+         "test_label": torch.from_numpy(test_actions)
+     }
+
+     os.makedirs("datasets", exist_ok=True)
+     torch.save(dataset, dataset_path)
+
+     video_files = glob.glob(os.path.join("videos", f"{algo}-{env_name}-episode*.mp4"))
+     clips = [VideoFileClip(file) for file in video_files]
+     final_clip = concatenate_videoclips(clips)
+     final_clip.write_videofile(video_path, codec="libx264", fps=24)
+
+     return dataset_path, video_path
+
+
+ if __name__ == "__main__":
+     generate_dataset_from_expert("ppo", "CartPole-v1", force=True)
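
For reference, the dataset file written by `generate_dataset_from_expert` is a plain dict of tensors that `InterpretablePolicyExtractor.train_from_dataset` later reloads with `torch.load`. A small sketch for inspecting one; the CartPole-v1 path and the shapes in the comment are illustrative assumptions.

```python
import torch

# Keys match the dict built in generate_dataset_from_expert above.
dataset = torch.load("datasets/ppo-CartPole-v1.pt")

# For CartPole-v1 one would expect roughly:
#   train_input  (N, 4) float observations, train_label (N,) int discrete actions
#   test_input   (M, 4),                     test_label  (M,)
for key in ("train_input", "train_label", "test_input", "test_label"):
    print(key, tuple(dataset[key].shape), dataset[key].dtype)
```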