riiswa committed on
Commit
1240765
1 Parent(s): d9d70e0

Add warning on MuJoCo usage

Browse files
Files changed (5) hide show
  1. README.md +7 -0
  2. app.py +3 -1
  3. packages.txt +2 -1
  4. requirements.txt +1 -1
  5. utils.py +0 -4
README.md CHANGED
@@ -11,3 +11,10 @@ license: mit
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
11
  ---
12
 
13
  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
14
+
15
+ ### Application demo
16
+
17
+ - Choose an RL environment from the gymnasium library. A policy from a pre-trained Proximal Policy Optimization (PPO) agent will automatically be loaded and used to generate an expert dataset and videos of the agent's performance in the selected environment.
18
+ - Click the "Compute Symbolic Policy" button to train a KAN policy on the expert dataset. Once it is done, you can visualize the KAN network and watch videos of the KAN agent's performance in the selected environment!
19
+
20
+ <img alt="Interpretability app demo" src="demo/app_demo.gif">
app.py CHANGED
@@ -36,7 +36,7 @@ For more information about KAN you can read the [paper](https://arxiv.org/abs/24
36
  To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
37
  """
38
 
39
- envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v4", "Hopper-v4"]
40
 
41
 
42
  if __name__ == "__main__":
@@ -45,6 +45,8 @@ if __name__ == "__main__":
45
  def load_video_and_dataset(_env_name):
46
  env_name = _env_name
47
 
 
 
48
  dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
49
  return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
50
  "dataset_path": dataset_path,
 
36
  To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
37
  """
38
 
39
+ envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3"]
40
 
41
 
42
  if __name__ == "__main__":
 
45
  def load_video_and_dataset(_env_name):
46
  env_name = _env_name
47
 
48
+ if env_name.startswith("Swimmer") or env_name.startswith("Hopper-v3"):
49
+ gr.Warning("We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework.")
50
  dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
51
  return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
52
  "dataset_path": dataset_path,
packages.txt CHANGED
@@ -3,4 +3,5 @@ libgl1-mesa-glx
3
  libglew-dev
4
  libosmesa6-dev
5
  software-properties-common
6
- patchelf
 
 
3
  libglew-dev
4
  libosmesa6-dev
5
  software-properties-common
6
+ patchelf
7
+ swig
requirements.txt CHANGED
@@ -13,4 +13,4 @@ stable_baselines3
13
  rl_zoo3
14
  gym
15
  shimmy>=0.2.1
16
- mujoco-py
 
13
  rl_zoo3
14
  gym
15
  shimmy>=0.2.1
16
+ free-mujoco-py
utils.py CHANGED
@@ -112,10 +112,6 @@ def rollouts(env, policy, num_episodes=1):
112
  def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
113
  if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
114
  install_mujoco()
115
- if env_name == "Swimmer-v4":
116
- env_name = "Swimmer-v3"
117
- elif env_name == "Hopper-v4":
118
- env_name = "Hopper-v3"
119
  dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
120
  video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
121
  if os.path.exists(dataset_path) and os.path.exists(video_path) and not force:
 
112
  def generate_dataset_from_expert(algo, env_name, num_train_episodes=5, num_test_episodes=2, force=False):
113
  if env_name.startswith("Swimmer") or env_name.startswith("Hopper"):
114
  install_mujoco()
 
 
 
 
115
  dataset_path = os.path.join("datasets", f"{algo}-{env_name}.pt")
116
  video_path = os.path.join("videos", f"{algo}-{env_name}.mp4")
117
  if os.path.exists(dataset_path) and os.path.exists(video_path) and not force: