Try to repair MuJoCo
Browse files
app.py
CHANGED
@@ -38,7 +38,7 @@ To follow the progress of KAN in RL you can check the repo [kanrl](https://githu
|
|
38 |
[![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
|
39 |
"""
|
40 |
|
41 |
-
envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3"]
|
42 |
|
43 |
|
44 |
if __name__ == "__main__":
|
@@ -46,10 +46,11 @@ if __name__ == "__main__":
|
|
46 |
|
47 |
def load_video_and_dataset(_env_name):
|
48 |
env_name = _env_name
|
|
|
|
|
|
|
49 |
|
50 |
-
|
51 |
-
gr.Warning("We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework.")
|
52 |
-
dataset_path, video_path = generate_dataset_from_expert("ppo", _env_name, 15, 3)
|
53 |
return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
|
54 |
"dataset_path": dataset_path,
|
55 |
"ipe": None,
|
@@ -126,7 +127,7 @@ if __name__ == "__main__":
|
|
126 |
|
127 |
with gr.Row():
|
128 |
with gr.Column():
|
129 |
-
gr.Markdown("### Pretrained policy loading (PPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
|
130 |
choice = gr.Dropdown(envs, label="Environment name")
|
131 |
expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
|
132 |
kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
|
|
|
38 |
[![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
|
39 |
"""
|
40 |
|
41 |
+
# Environment ids offered in the "Environment name" dropdown.
envs = [
    "CartPole-v1",
    "MountainCar-v0",
    "Acrobot-v1",
    "Pendulum-v1",
    "MountainCarContinuous-v0",
    "LunarLander-v2",
    "Swimmer-v3",
    "Hopper-v3",
    "HalfCheetah-v3",
    "Walker2d-v3",
]
|
42 |
|
43 |
|
44 |
if __name__ == "__main__":
|
|
|
46 |
|
47 |
def load_video_and_dataset(_env_name):
|
48 |
env_name = _env_name
|
49 |
+
agent = "ppo"
|
50 |
+
if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
|
51 |
+
agent ="trpo"
|
52 |
|
53 |
+
dataset_path, video_path = generate_dataset_from_expert(agent, _env_name, 15, 3)
|
|
|
|
|
54 |
return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
|
55 |
"dataset_path": dataset_path,
|
56 |
"ipe": None,
|
|
|
127 |
|
128 |
with gr.Row():
|
129 |
with gr.Column():
|
130 |
+
gr.Markdown("### Pretrained policy loading (PPO or TRPO from [rl-baselines3-zoo](https://github.com/DLR-RM/rl-baselines3-zoo))")
|
131 |
choice = gr.Dropdown(envs, label="Environment name")
|
132 |
expert_video = gr.Video(label="Expert policy video", interactive=False, autoplay=True)
|
133 |
kan_widths = gr.Textbox(value="2", label="Widths of the hidden layers of the KAN, separated by commas (e.g. `3,3`). Leave empty if there are no hidden layers.")
|