riiswa commited on
Commit
e27178f
1 Parent(s): c0e5193

Try to repair Mujoco

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -38,7 +38,7 @@ To follow the progress of KAN in RL you can check the repo [kanrl](https://githu
38
  [![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
39
  """
40
 
41
- envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3"]
42
 
43
 
44
  if __name__ == "__main__":
@@ -46,10 +46,11 @@ if __name__ == "__main__":
46
 
47
  def load_video_and_dataset(_env_name):
48
  env_name = _env_name
 
 
 
49
 
50
- if env_name.startswith("Swimmer") or env_name.startswith("Hopper-v3"):
51
- gr.Warning("We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework.")
52
- dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
53
  return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
54
  "dataset_path": dataset_path,
55
  "ipe": None,
@@ -126,7 +127,7 @@ if __name__ == "__main__":
126
 
127
  with gr.Row():
128
  with gr.Column():
129
- gr.Markdown("### Pretrained policy loading (PPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
130
  choice = gr.Dropdown(envs, label="Environment name")
131
  expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
132
  kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
 
38
  [![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
39
  """
40
 
41
+ envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]
42
 
43
 
44
  if __name__ == "__main__":
 
46
 
47
  def load_video_and_dataset(_env_name):
48
  env_name = _env_name
49
+ agent = "ppo"
50
+ if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
51
+ agent ="trpo"
52
 
53
+ dataset_path, video_path = generate_dataset_from_expert(agent, _env_name, 15, 3)
 
 
54
  return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
55
  "dataset_path": dataset_path,
56
  "ipe": None,
 
127
 
128
  with gr.Row():
129
  with gr.Column():
130
+ gr.Markdown("### Pretrained policy loading (PPO or TRPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
131
  choice = gr.Dropdown(envs, label="Environment name")
132
  expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
133
  kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")