import glob
import os
import sys

import gradio as gr
import gymnasium as gym
import matplotlib.pyplot as plt
import numpy as np
import torch
from gymnasium.wrappers import RecordVideo
from moviepy.video.compositing.concatenate import concatenate_videoclips
from moviepy.video.io.VideoFileClip import VideoFileClip
from sympy import latex

from interpretable import InterpretablePolicyExtractor
from utils import generate_dataset_from_expert, rollouts
intro = """ |
|
# Making RL Policy Interpretable with Kolmogorov-Arnold Network 🧠 ➙ 🔢 |
|
|
|
Waris Radji<sup>1</sup>, Corentin Léger<sup>2</sup>, Hector Kohler<sup>1</sup> |
|
<small><sup>1</sup>[Inria, team Scool](https://team.inria.fr/scool/) <sup>2</sup>[Inria, team Flowers](https://flowers.inria.fr/)</small> |
|
|
|
|
|
|
|
In this demo, we showcase a method to make a trained Reinforcement Learning (RL) policy interpretable using the Kolmogorov-Arnold Network (KAN). The process involves transferring the knowledge from a pre-trained RL policy to a KAN. We achieve this by training the KAN to map actions from observations obtained from trajectories of the pre-trained policy. |
|
|
|
## Procedure |
|
|
|
- Train the KAN using observations from trajectories generated by a pre-trained RL policy, the KAN learns to map observations to corresponding actions. |
|
- Apply symbolic regression algorithms to the KAN's learned mapping. |
|
- Extract an interpretable policy expressed in symbolic form. |
|
|
|
For more information about KAN you can read the [paper](https://arxiv.org/abs/2404.19756), and check the [PyTorch official information](https://github.com/KindXiaoming/pykan). |
|
To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl). |
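
Below is a minimal sketch of the extraction loop (illustrative only, assuming the `pykan` `KAN` API; `obs_dim`, `act_dim`, `obs`, and `acts` are placeholders for the environment dimensions and the expert dataset tensors):

```python
from kan import KAN

# Fit a KAN that maps expert observations to expert actions.
model = KAN(width=[obs_dim, 2, act_dim])
dataset = {"train_input": obs, "train_label": acts,
           "test_input": obs, "test_label": acts}
model.train(dataset, steps=50)

# Snap each learned spline to its best-fitting symbolic primitive...
model.auto_symbolic(lib=['x', 'x^2', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs'])

# ...and read the policy back as closed-form formulas.
formulas, symbols = model.symbolic_formula()
```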
"""


envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v4", "Hopper-v4"]


class Logger:
    """Tee stream that mirrors writes to both the terminal and a log file."""

    def __init__(self, filename):
        self.terminal = sys.stdout
        self.log = open(filename, "w")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        self.terminal.flush()
        self.log.flush()

    def isatty(self):
        return False


# Mirror stdout and stderr to "output.log" so the UI can display live logs;
# share a single Logger so the file is opened (and truncated) only once.
sys.stdout = Logger("output.log")
sys.stderr = sys.stdout


def read_logs():
    sys.stdout.flush()
    with open("output.log", "r") as f:
        return f.read()


if __name__ == "__main__":
    torch.set_default_dtype(torch.float32)
    dataset_path = None
    ipe = None
    env_name = None

    def load_video_and_dataset(_env_name):
        global dataset_path
        global env_name
        env_name = _env_name

        dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
        return video_path, gr.Button("Compute the symbolic policy!", interactive=True)
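
    # Note: generate_dataset_from_expert (see utils) is assumed to roll out the
    # pretrained PPO agent, save an (observation, action) dataset to disk, and
    # record a video of the expert episodes; the exact semantics of the two
    # integer arguments are defined in utils.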

    def parse_integer_list(input_str):
        if not input_str or input_str.isspace():
            return None

        elements = input_str.split(',')

        try:
            return tuple(int(elem.strip()) for elem in elements)
        except ValueError:
            return False
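
    # Examples: parse_integer_list("3,3") -> (3, 3); parse_integer_list("") -> None;
    # parse_integer_list("a,b") -> False (signals malformed input to the caller).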

    def extract_interpretable_policy(env_name, kan_widths):
        global ipe

        widths = parse_integer_list(kan_widths)
        if widths is False:
            gr.Warning(f"Could not parse the widths '{kan_widths}'; the current run is executed with no hidden layer.")
            widths = None

        ipe = InterpretablePolicyExtractor(env_name, widths)
        ipe.train_from_dataset(dataset_path, steps=50)

        ipe.policy.prune()
        ipe.policy.plot(mask=True, scale=5)

        # Render the matplotlib figure as an RGBA array for the Gradio image.
        fig = plt.gcf()
        fig.canvas.draw()
        return np.array(fig.canvas.renderer.buffer_rgba())

    def symbolic_policy():
        global ipe
        global env_name
        # Candidate primitives for pykan's automatic symbolic regression.
        lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']
        ipe.policy.auto_symbolic(lib=lib)
        env = gym.make(env_name, render_mode="rgb_array")
        env = RecordVideo(env, video_folder="videos", episode_trigger=lambda x: True, name_prefix=f"kan-{env_name}")

        rollouts(env, ipe.forward, 2)

        # Concatenate the per-episode recordings into a single video.
        video_path = os.path.join("videos", f"kan-{env_name}.mp4")
        video_files = glob.glob(os.path.join("videos", f"kan-{env_name}-episode*.mp4"))
        clips = [VideoFileClip(file) for file in video_files]
        final_clip = concatenate_videoclips(clips)
        final_clip.write_videofile(video_path, codec="libx264", fps=24)

        symbolic_formula = "### The symbolic formula of the policy is:"
        formulas = ipe.policy.symbolic_formula()[0]
        for i, formula in enumerate(formulas):
            symbolic_formula += "\n$$ a_" + str(i) + "=" + latex(formula) + "$$"
        if ipe._action_is_discrete:
            symbolic_formula += "\n" + r"$$ a = \underset{i}{\mathrm{argmax}} \ a_i.$$"

        return video_path, symbolic_formula

    css = """
    #formula {overflow-x: auto !important;}
    """

    with gr.Blocks(theme='gradio/monochrome', css=css) as app:
        gr.Markdown(intro)

        with gr.Row():
            with gr.Column():
                gr.Markdown("### Pretrained policy loading (PPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
                choice = gr.Dropdown(envs, label="Environment name")
                expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
                kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
                button = gr.Button("Compute the symbolic policy!", interactive=False)
            with gr.Column():
                gr.Markdown("### Symbolic policy extraction")
                kan_architecture = gr.Image(interactive=False, label="KAN architecture")
                sym_video = gr.Video(label="Symbolic policy video", interactive=False, autoplay=True)
                sym_formula = gr.Markdown(elem_id="formula")
        with gr.Accordion("See logs"):
            logs = gr.Textbox(label="Logs", interactive=False)

        choice.input(load_video_and_dataset, inputs=[choice], outputs=[expert_video, button])
        button.click(extract_interpretable_policy, inputs=[choice, kan_widths], outputs=[kan_architecture]).then(
            symbolic_policy, inputs=[], outputs=[sym_video, sym_formula]
        )
        app.load(read_logs, None, logs, every=1)

    app.launch()