First commit
- .gitignore +8 -0
- app.py +168 -0
- example.py +26 -0
- interpretable.py +47 -0
- requirements.txt +17 -0
- utils.py +178 -0
.gitignore
ADDED
@@ -0,0 +1,8 @@
+datasets
+videos
+figures
+*.log
+*.png
+mujoco210
+.DS_Store
+__pycache__/
app.py
ADDED
@@ -0,0 +1,168 @@
+import glob
+import os
+
+import gymnasium as gym
+import numpy as np
+from gymnasium.wrappers import RecordVideo
+from moviepy.video.compositing.concatenate import concatenate_videoclips
+from moviepy.video.io.VideoFileClip import VideoFileClip
+from sympy import latex
+
+from interpretable import InterpretablePolicyExtractor
+from utils import generate_dataset_from_expert, rollouts
+import matplotlib.pyplot as plt
+
+import torch
+
+import gradio as gr
+import sys
+
+intro = """
+# Making RL Policy Interpretable with Kolmogorov-Arnold Network 🧠 ➙ 🔢
+
+Waris Radji<sup>1</sup>, Corentin Léger<sup>2</sup>, Hector Kohler<sup>1</sup>
+<small><sup>1</sup>[Inria, team Scool](https://team.inria.fr/scool/) <sup>2</sup>[Inria, team Flowers](https://flowers.inria.fr/)</small>
+
+
+
+In this demo, we showcase a method to make a trained Reinforcement Learning (RL) policy interpretable using a Kolmogorov-Arnold Network (KAN). The knowledge of a pre-trained RL policy is transferred to a KAN by training the KAN to map the observations collected along the pre-trained policy's trajectories to the corresponding actions.
+
+## Procedure
+
+- Train the KAN on observations from trajectories generated by a pre-trained RL policy; the KAN learns to map observations to the corresponding actions.
+- Apply symbolic regression algorithms to the KAN's learned mapping.
+- Extract an interpretable policy expressed in symbolic form.
+
+For more information about KAN, you can read the [paper](https://arxiv.org/abs/2404.19756) and check the [official PyTorch implementation](https://github.com/KindXiaoming/pykan).
+To follow the progress of KAN in RL, you can check the [kanrl](https://github.com/riiswa/kanrl) repo.
+"""
+
+envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v4", "Hopper-v4"]
+
+
+class Logger:
+    def __init__(self, filename):
+        self.terminal = sys.stdout
+        self.log = open(filename, "w")
+
+    def write(self, message):
+        self.terminal.write(message)
+        self.log.write(message)
+
+    def flush(self):
+        self.terminal.flush()
+        self.log.flush()
+
+    def isatty(self):
+        return False
+
+
+sys.stdout = Logger("output.log")
+sys.stderr = Logger("output.log")
+
+
+def read_logs():
+    sys.stdout.flush()
+    with open("output.log", "r") as f:
+        return f.read()
+
+
+if __name__ == "__main__":
+    torch.set_default_dtype(torch.float32)
+    dataset_path = None
+    ipe = None
+    env_name = None
+
+    def load_video_and_dataset(_env_name):
+        global dataset_path
+        global env_name
+        env_name = _env_name
+
+        dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
+        return video_path, gr.Button("Compute the symbolic policy!", interactive=True)
+
+
+    def parse_integer_list(input_str):
+        if not input_str or input_str.isspace():
+            return None
+
+        elements = input_str.split(',')
+
+        try:
+            int_list = tuple([int(elem.strip()) for elem in elements])
+            return int_list
+        except ValueError:
+            return False
+
+    def extract_interpretable_policy(env_name, kan_widths):
+        global ipe
+
+        widths = parse_integer_list(kan_widths)
+        if widths is False:  # parse_integer_list returns False when parsing fails
+            gr.Warning(f"Could not parse the widths '{kan_widths}'; please enter comma-separated integers. The current run is executed with no hidden layer.")
+            widths = None
+
+        ipe = InterpretablePolicyExtractor(env_name, widths)
+        ipe.train_from_dataset(dataset_path, steps=50)
+
+        ipe.policy.prune()
+        ipe.policy.plot(mask=True, scale=5)
+
+        fig = plt.gcf()
+        fig.canvas.draw()
+        return np.array(fig.canvas.renderer.buffer_rgba())
+
+    def symbolic_policy():
+        global ipe
+        global env_name
+        lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
+        ipe.policy.auto_symbolic(lib=lib)
+        env = gym.make(env_name, render_mode="rgb_array")
+        env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
+
+        rollouts(env, ipe.forward, 2)
+
+        video_path = os.path.join("videos", f"kan-{env_name}.mp4")
+        video_files = glob.glob(os.path.join("videos", f"kan-{env_name}-episode*.mp4"))
+        clips = [VideoFileClip(file) for file in video_files]
+        final_clip = concatenate_videoclips(clips)
+        final_clip.write_videofile(video_path, codec="libx264", fps=24)
+
+        symbolic_formula = "### The symbolic formula of the policy is:"
+        formulas = ipe.policy.symbolic_formula()[0]
+        for i, formula in enumerate(formulas):
+            symbolic_formula += "\n$$ a_" + str(i) + "=" + latex(formula) + "$$"
+        if ipe._action_is_discrete:
+            symbolic_formula += "\n" + r"$$ a = \underset{i}{\mathrm{argmax}} \ a_i.$$"
+
+        return video_path, symbolic_formula
+
+
+    css = """
+    #formula {overflow-x: auto !important;}
+    """
+
+    with gr.Blocks(theme='gradio/monochrome', css=css) as app:
+        gr.Markdown(intro)
+
+        with gr.Row():
+            with gr.Column():
+                gr.Markdown("### Pretrained policy loading (PPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
+                choice = gr.Dropdown(envs, label="Environment name")
+                expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
+                kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
+                button = gr.Button("Compute the symbolic policy!", interactive=False)
+            with gr.Column():
+                gr.Markdown("### Symbolic policy extraction")
+                kan_architecture = gr.Image(interactive=False, label="KAN architecture")
+                sym_video = gr.Video(label="Symbolic policy video", interactive=False, autoplay=True)
+                sym_formula = gr.Markdown(elem_id="formula")
+        with gr.Accordion("See logs"):
+            logs = gr.Textbox(label="Logs", interactive=False)
+        choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button])
+        button.click(extract_interpretable_policy, inputs=[choice, kan_widths], outputs=[kan_architecture]).then(
+            symbolic_policy, inputs=[], outputs=[sym_video, sym_formula]
+        )
+        app.load(read_logs, None, logs, every=1)

+    app.launch()
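The procedure described in the intro above can also be exercised outside the Gradio app. The sketch below distills a hand-written CartPole heuristic into a KAN and prints its symbolic form; the heuristic, episode counts, and symbolic library are illustrative assumptions, while the `KAN`, `train`, `auto_symbolic`, and `symbolic_formula` calls mirror the ones used in this commit.

# Minimal sketch (not part of this commit): distill a hand-written CartPole
# heuristic (a stand-in for the pretrained PPO expert) into a KAN and read out
# its symbolic form. Heuristic and episode counts are illustrative assumptions.
import gymnasium as gym
import numpy as np
import torch
from kan import KAN

def heuristic(obs):
    # Push the cart towards the side the pole is falling to.
    return int(obs[2] + obs[3] > 0)

def collect(env, policy, num_episodes):
    observations, actions = [], []
    for _ in range(num_episodes):
        obs, _ = env.reset()
        done = False
        while not done:
            act = policy(obs)
            observations.append(obs)
            actions.append(act)
            obs, _, terminated, truncated, _ = env.step(act)
            done = terminated or truncated
    return torch.tensor(np.array(observations)), torch.tensor(actions)

env = gym.make("CartPole-v1")
train_x, train_y = collect(env, heuristic, 10)
test_x, test_y = collect(env, heuristic, 2)
dataset = {"train_input": train_x, "train_label": train_y,
           "test_input": test_x, "test_label": test_y}

policy = KAN(width=[4, 2])  # observation_dim -> action_dim, no hidden layer
policy.train(dataset, opt="LBFGS", steps=20, loss_fn=torch.nn.CrossEntropyLoss())
policy.auto_symbolic(lib=['x', 'x^2', 'tanh', 'sin'])
print(policy.symbolic_formula()[0])  # one symbolic expression per action logit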
example.py
ADDED
@@ -0,0 +1,26 @@
+import gymnasium as gym
+from gymnasium.wrappers import RecordVideo
+from matplotlib import pyplot as plt
+
+from interpretable import InterpretablePolicyExtractor
+from utils import generate_dataset_from_expert, rollouts
+
+if __name__ == "__main__":
+    env_name = "CartPole-v1"
+    dataset_path, _ = generate_dataset_from_expert("ppo", env_name, force=True)
+    ipe = InterpretablePolicyExtractor(env_name)
+    results = ipe.train_from_dataset(dataset_path)
+    ipe.policy.prune()
+    ipe.policy.plot(mask=True)
+    plt.savefig("kan-policy.png")
+
+    env = gym.make(env_name, render_mode="rgb_array")
+    env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")
+
+    ipe.policy.auto_symbolic()
+
+    ipe.policy.plot(mask=True)
+    plt.savefig("sym-policy.png")
+    print(ipe.policy.symbolic_formula())
+
+    rollouts(env, ipe.forward, 2)
interpretable.py
ADDED
@@ -0,0 +1,47 @@
+import torch
+from typing import Dict, Tuple, Optional, Callable, Union
+import gymnasium as gym
+from kan import KAN
+import numpy as np
+
+
+def extract_dim(space: gym.Space):
+    if isinstance(space, gym.spaces.Box) and len(space.shape) == 1:
+        return space.shape[0], False
+    elif isinstance(space, gym.spaces.Discrete):
+        return space.n, True
+    else:
+        raise NotImplementedError(f"There is no support for space {space}.")
+
+
+class InterpretablePolicyExtractor:
+    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
+
+    def __init__(self, env_name: str, hidden_widths: Optional[Tuple[int]] = None):
+        self.env = gym.make(env_name)
+        if hidden_widths is None:
+            hidden_widths = []
+        observation_dim, self._observation_is_discrete = extract_dim(self.env.observation_space)
+        action_dim, self._action_is_discrete = extract_dim(self.env.action_space)
+        self.policy = KAN(width=[observation_dim, *hidden_widths, action_dim])
+        self.loss_fn = torch.nn.MSELoss() if not self._action_is_discrete else torch.nn.CrossEntropyLoss()
+
+    def train_from_dataset(self, dataset: Union[Dict[str, torch.Tensor], str], steps: int = 20):
+        if isinstance(dataset, str):
+            dataset = torch.load(dataset)
+        if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
+            dataset["train_label"] = dataset["train_label"][:, None]
+        if dataset["test_label"].ndim == 1 and not self._action_is_discrete:
+            dataset["test_label"] = dataset["test_label"][:, None]
+        return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
+
+    def forward(self, observation):
+        observation = torch.from_numpy(observation)
+        action = self.policy(observation.unsqueeze(0))
+        if self._action_is_discrete:
+            return action.argmax(axis=-1).squeeze().item()
+        else:
+            return action.squeeze(0).detach().numpy()
+
+    def train_from_policy(self, policy: Callable[[np.ndarray], Union[np.ndarray, int, float]], steps: int):
+        raise NotImplementedError()  # TODO
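`train_from_dataset` accepts either a dataset path saved by `utils.py` or an in-memory dict with the same keys (`train_input`, `train_label`, `test_input`, `test_label`). A minimal sketch with random tensors, whose shapes are illustrative assumptions for CartPole-v1:

import numpy as np
import torch
from interpretable import InterpretablePolicyExtractor

# Same keys as the dict written by utils.generate_dataset_from_expert;
# the random data and sizes are placeholders for a real expert dataset.
dataset = {
    "train_input": torch.randn(256, 4),          # observations
    "train_label": torch.randint(0, 2, (256,)),  # discrete actions
    "test_input": torch.randn(64, 4),
    "test_label": torch.randint(0, 2, (64,)),
}

ipe = InterpretablePolicyExtractor("CartPole-v1", hidden_widths=(2,))
ipe.train_from_dataset(dataset, steps=20)
print(ipe.forward(np.zeros(4, dtype=np.float32)))  # -> an integer action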
requirements.txt
ADDED
@@ -0,0 +1,17 @@
+gymnasium[box2d,mujoco]~=0.29.1
+torch~=2.2.2
+numpy~=1.26.4
+pykan~=0.0.2
+tqdm~=4.66.2
+scikit-learn~=1.4.2
+matplotlib~=3.8.4
+moviepy~=1.0.3
+huggingface_hub
+gradio
+huggingface_sb3
+stable_baselines3
+rl_zoo3
+gym
+shimmy>=0.2.1
+mujoco-py
+cython<3
utils.py
ADDED
@@ -0,0 +1,178 @@
+import glob
+import os
+import pickle
+
+import torch
+import numpy as np
+import gymnasium as gym
+from huggingface_hub.utils import EntryNotFoundError
+from huggingface_sb3 import load_from_hub
+from moviepy.video.compositing.concatenate import concatenate_videoclips
+from moviepy.video.io.VideoFileClip import VideoFileClip
+from rl_zoo3 import ALGOS
+from gymnasium.wrappers import RecordVideo
+from stable_baselines3.common.running_mean_std import RunningMeanStd
+
+import tarfile
+import urllib.request
+
+
+def install_mujoco():
+    mujoco_url = "https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz"
+    mujoco_file = "mujoco210-linux-x86_64.tar.gz"
+    mujoco_dir = "mujoco210"
+
+    # Check if the directory already exists
+    if not os.path.exists("mujoco210"):
+        # Download Mujoco if it does not exist
+        print("Downloading Mujoco...")
+        urllib.request.urlretrieve(mujoco_url, mujoco_file)
+
+        # Extract Mujoco
+        print("Extracting Mujoco...")
+        with tarfile.open(mujoco_file, "r:gz") as tar:
+            tar.extractall()
+
+        # Clean up the downloaded tar file
+        os.remove(mujoco_file)
+
+        print("Mujoco installed successfully!")
+    else:
+        print("Mujoco already installed.")
+
+    # Set environment variable MUJOCO_PY_MUJOCO_PATH
+    os.environ["MUJOCO_PY_MUJOCO_PATH"] = os.path.abspath(mujoco_dir)
+
+    ld_library_path = os.environ.get("LD_LIBRARY_PATH", "")
+    mujoco_bin_path = os.path.join(os.path.abspath(mujoco_dir), "bin")
+    if mujoco_bin_path not in ld_library_path:
+        os.environ["LD_LIBRARY_PATH"] = ld_library_path + ":" + mujoco_bin_path
+
+
+class NormalizeObservation(gym.Wrapper):
+    def __init__(self, env: gym.Env, clip_obs: float, obs_rms: RunningMeanStd, epsilon: float):
+        gym.Wrapper.__init__(self, env)
+        self.clip_obs = clip_obs
+        self.obs_rms = obs_rms
+        self.epsilon = epsilon
+
+    def step(self, action):
+        observation, reward, terminated, truncated, info = self.env.step(action)
+        observation = self.normalize(np.array([observation]))[0]
+        return observation, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        observation, info = self.env.reset(**kwargs)
+        return self.normalize(np.array([observation]))[0], info
+
+    def normalize(self, obs):
+        return np.clip((obs - self.obs_rms.mean) / np.sqrt(self.obs_rms.var + self.epsilon), -self.clip_obs, self.clip_obs)
+
+
+class CreateDataset(gym.Wrapper):
+    def __init__(self, env: gym.Env):
+        gym.Wrapper.__init__(self, env)
+        self.observations = []
+        self.actions = []
+        self.last_observation = None
+
+    def step(self, action):
+        self.observations.append(self.last_observation)
+        self.actions.append(action)
+        observation, reward, terminated, truncated, info = self.env.step(action)
+        self.last_observation = observation
+        return observation, reward, terminated, truncated, info
+
+    def reset(self, **kwargs):
+        observation, info = self.env.reset(**kwargs)
+        self.last_observation = observation
+        return observation, info
+
+    def get_dataset(self):
+        if isinstance(self.env.action_space, gym.spaces.Box) and self.env.action_space.shape != (1,):
+            actions = np.vstack(self.actions)
+        else:
+            actions = np.hstack(self.actions)
+        return np.vstack(self.observations), actions
+
+
+def rollouts(env, policy, num_episodes=1):
+    for episode in range(num_episodes):
+        done = False
+        observation, _ = env.reset()
+        while not done:
+            action = policy(observation)
+            observation, reward, terminated, truncated, _ = env.step(action)
+            done = terminated or truncated
+    env.close()
+
+
+def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
+    if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
+        install_mujoco()
+        if env_name == "Swimmer-v4":
+            env_name = "Swimmer-v3"
+        elif env_name == "Hopper-v4":
+            env_name = "Hopper-v3"
+    dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
+    video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
+    if os.path.exists(dataset_path) and os.path.exists(video_path) and not force:
+        return dataset_path, video_path
+    repo_id = f"sb3/{algo}-{env_name}"
+    policy_file = f"{algo}-{env_name}.zip"
+
+    expert_path = load_from_hub(repo_id, policy_file)
+    try:
+        vec_normalize_path = load_from_hub(repo_id, "vec_normalize.pkl")
+        with open(vec_normalize_path, "rb") as f:
+            vec_normalize = pickle.load(f)
+        if vec_normalize.norm_obs:
+            vec_normalize_params = {"clip_obs": vec_normalize.clip_obs, "obs_rms": vec_normalize.obs_rms, "epsilon": vec_normalize.epsilon}
+        else:
+            vec_normalize_params = None
+    except EntryNotFoundError:
+        vec_normalize_params = None
+
+    expert = ALGOS[algo].load(expert_path)
+    train_env = gym.make(env_name)
+    train_env = CreateDataset(train_env)
+    if vec_normalize_params is not None:
+        train_env = NormalizeObservation(train_env, **vec_normalize_params)
+    test_env = gym.make(env_name, render_mode="rgb_array")
+    test_env = CreateDataset(test_env)
+    if vec_normalize_params is not None:
+        test_env = NormalizeObservation(test_env, **vec_normalize_params)
+    test_env = RecordVideo(test_env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"{algo}-{env_name}")
+
+    def policy(obs):
+        return expert.predict(obs, deterministic=True)[0]
+
+    os.makedirs("videos", exist_ok=True)
+    rollouts(train_env, policy, num_train_episodes)
+    rollouts(test_env, policy, num_test_episodes)
+
+    train_observations, train_actions = train_env.get_dataset()
+    test_observations, test_actions = test_env.get_dataset()
+
+    dataset = {
+        "train_input": torch.from_numpy(train_observations),
+        "test_input": torch.from_numpy(test_observations),
+        "train_label": torch.from_numpy(train_actions),
+        "test_label": torch.from_numpy(test_actions)
+    }
+
+    os.makedirs("datasets", exist_ok=True)
+    torch.save(dataset, dataset_path)
+
+    video_files = glob.glob(os.path.join("videos", f"{algo}-{env_name}-episode*.mp4"))
+    clips = [VideoFileClip(file) for file in video_files]
+    final_clip = concatenate_videoclips(clips)
+    final_clip.write_videofile(video_path, codec="libx264", fps=24)
+
+    return dataset_path, video_path
+
+
+if __name__ == "__main__":
+    generate_dataset_from_expert("ppo", "CartPole-v1", force=True)