Try to fix type error
Browse files
app.py
CHANGED
@@ -33,7 +33,7 @@ In this demo, we showcase a method to make a trained Reinforcement Learning (RL)
|
|
33 |
- Extract an interpretable policy expressed in symbolic form.
|
34 |
|
35 |
For more information about KAN you can read the [paper](https://arxiv.org/abs/2404.19756), and check the [PyTorch official information](https://github.com/KindXiaoming/pykan).
|
36 |
-
To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
|
37 |
|
38 |
[](https://github.com/riiswa/kanrl)
|
39 |
"""
|
@@ -46,6 +46,13 @@ if __name__ == "__main__":
|
|
46 |
|
47 |
def load_video_and_dataset(_env_name):
|
48 |
env_name = _env_name
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
49 |
agent = "ppo"
|
50 |
if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
|
51 |
agent ="trpo"
|
|
|
33 |
- Extract an interpretable policy expressed in symbolic form.
|
34 |
|
35 |
For more information about KAN you can read the [paper](https://arxiv.org/abs/2404.19756), and check the [PyTorch official information](https://github.com/KindXiaoming/pykan).
|
36 |
+
To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl) (you can run this app locally).
|
37 |
|
38 |
[](https://github.com/riiswa/kanrl)
|
39 |
"""
|
|
|
46 |
|
47 |
def load_video_and_dataset(_env_name):
|
48 |
env_name = _env_name
|
49 |
+
if env_name in ["Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]:
|
50 |
+
gr.Warning(
|
51 |
+
"We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework."
|
52 |
+
)
|
53 |
+
torch.set_default_dtype(torch.float64)
|
54 |
+
else:
|
55 |
+
torch.set_default_dtype(torch.float32)
|
56 |
agent = "ppo"
|
57 |
if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
|
58 |
agent ="trpo"
|