riiswa commited on
Commit
72c067e
·
1 Parent(s): e27178f

Try to fix type error

Browse files
Files changed (1) hide show
  1. app.py +8 -1
app.py CHANGED
@@ -33,7 +33,7 @@ In this demo, we showcase a method to make a trained Reinforcement Learning (RL)
33
  - Extract an interpretable policy expressed in symbolic form.
34
 
35
  For more information about KAN you can read the [paper](https://arxiv.org/abs/2404.19756), and check the [PyTorch official information](https://github.com/KindXiaoming/pykan).
36
- To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl).
37
 
38
  [![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
39
  """
@@ -46,6 +46,13 @@ if __name__ == "__main__":
46
 
47
  def load_video_and_dataset(_env_name):
48
  env_name = _env_name
 
 
 
 
 
 
 
49
  agent = "ppo"
50
  if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
51
  agent ="trpo"
 
33
  - Extract an interpretable policy expressed in symbolic form.
34
 
35
  For more information about KAN you can read the [paper](https://arxiv.org/abs/2404.19756), and check the [PyTorch official information](https://github.com/KindXiaoming/pykan).
36
+ To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl) (you can run this app locally).
37
 
38
  [![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
39
  """
 
46
 
47
  def load_video_and_dataset(_env_name):
48
  env_name = _env_name
49
+ if env_name in ["Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]:
50
+ gr.Warning(
51
+ "We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework."
52
+ )
53
+ torch.set_default_dtype(torch.float64)
54
+ else:
55
+ torch.set_default_dtype(torch.float32)
56
  agent = "ppo"
57
  if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
58
  agent ="trpo"