Try to debug
- app.py (+2, -6)
- interpretable.py (+1, -3)
app.py CHANGED

```diff
@@ -37,7 +37,7 @@ To follow the progress of KAN in RL you can check the repo [kanrl](https://github.com/riiswa/kanrl)
 
 [![riiswa/kanrl - GitHub](https://gh-card.dev/repos/riiswa/kanrl.svg)](https://github.com/riiswa/kanrl)
 
-*Please be patient, as the process may take a few minutes to run, especially in environments with large state/action spaces or with a complex KAN architecture.*
+*Please be patient, as the process may take a few minutes to run, especially in environments with large state/action spaces or with a complex KAN architecture. For optimal performance, default parameters may not suffice. Feel free to experiment with different settings to achieve desired results.*
 """
 
 envs = ["CartPole-v1", "MountainCar-v0", "Acrobot-v1", "Pendulum-v1", "MountainCarContinuous-v0", "LunarLander-v2", "Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]
@@ -48,13 +48,9 @@ if __name__ == "__main__":
 
     def load_video_and_dataset(_env_name):
         env_name = _env_name
-        if env_name in ["Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]:
-            gr.Warning(
-                "We're currently in the process of adding support for Mujoco environments, so the application may encounter crashes during this phase. We encourage contributors to join us in the repository https://github.com/riiswa/kanrl to assist in the development and support of other environments. Your contributions are invaluable in ensuring a robust and comprehensive framework."
-            )
         agent = "ppo"
         if env_name == "Swimmer-v3" or env_name == "Walker2d-v3":
-            agent ="trpo"
+            agent = "trpo"
 
         dataset_path, video_path = generate_dataset_from_expert(agent, _env_name, 15, 3)
         return video_path, gr.Button("Compute the symbolic policy!", interactive=True), {
```
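For context, `gr.Warning` pops a non-blocking toast in the Gradio UI while the handler keeps running (unlike raising `gr.Error`, which aborts the event); that is the pattern the removed Mujoco notice used. A minimal sketch of it, where the `pick_env` handler and component names are illustrative rather than taken from app.py:

```python
import gradio as gr

MUJOCO_ENVS = ["Swimmer-v3", "Hopper-v3", "HalfCheetah-v3", "Walker2d-v3"]

def pick_env(env_name):
    # gr.Warning shows a toast but lets the handler continue.
    if env_name in MUJOCO_ENVS:
        gr.Warning("Mujoco support is experimental; the app may crash.")
    return env_name

with gr.Blocks() as demo:
    dropdown = gr.Dropdown(MUJOCO_ENVS + ["CartPole-v1"], label="Environment")
    chosen = gr.Textbox(label="Selected")
    dropdown.change(pick_env, dropdown, chosen)

if __name__ == "__main__":
    demo.launch()
```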
interpretable.py CHANGED

```diff
@@ -35,12 +35,10 @@ class InterpretablePolicyExtractor:
         dataset["test_label"] = dataset["test_label"][:, None]
         dataset["train_input"] = dataset["train_input"].float()
         dataset["test_input"] = dataset["test_input"].float()
-        for k,v in dataset.items():
-            print(k, v.shape, v.dtype)
         return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)
 
     def forward(self, observation):
-        observation = torch.from_numpy(observation)
+        observation = torch.from_numpy(observation).float()
         action = self.policy(observation.unsqueeze(0))
         if self._action_is_discrete:
             return action.argmax(axis=-1).squeeze().item()
```