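"""Distill a reinforcement-learning policy into a Kolmogorov-Arnold Network (KAN).

The KAN is trained to imitate (observation, action) pairs collected in a Gymnasium
environment, so the resulting policy can later be inspected as symbolic formulas.
"""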
from typing import Callable, Dict, Optional, Tuple, Union

import gymnasium as gym
import numpy as np
import torch
from kan import KAN


def extract_dim(space: gym.Space) -> Tuple[int, bool]:
    """Return the flat dimension of a Gymnasium space and whether it is discrete."""
    if isinstance(space, gym.spaces.Box) and len(space.shape) == 1:
        return space.shape[0], False
    elif isinstance(space, gym.spaces.Discrete):
        return space.n, True
    else:
        raise NotImplementedError(f"There is no support for space {space}.")


class InterpretablePolicyExtractor:
    # Candidate symbolic primitives used when converting the trained KAN into closed-form expressions.
    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']

    def __init__(self, env_name: str, hidden_widths: Optional[Tuple[int, ...]] = None):
        self.env = gym.make(env_name)
        if hidden_widths is None:
            hidden_widths = []
        observation_dim, self._observation_is_discrete = extract_dim(self.env.observation_space)
        action_dim, self._action_is_discrete = extract_dim(self.env.action_space)
        self.policy = KAN(width=[observation_dim, *hidden_widths, action_dim])
        self.loss_fn = torch.nn.CrossEntropyLoss() if self._action_is_discrete else torch.nn.MSELoss()

    def train_from_dataset(self, dataset: Union[Dict[str, torch.Tensor], str], steps: int = 20):
        if isinstance(dataset, str):
            dataset = torch.load(dataset)
        # Regression targets must be 2-D for MSELoss; add a trailing feature axis when needed.
        if dataset["train_label"].ndim == 1 and not self._action_is_discrete:
            dataset["train_label"] = dataset["train_label"][:, None]
        if dataset["test_label"].ndim == 1 and not self._action_is_discrete:
            dataset["test_label"] = dataset["test_label"][:, None]
        dataset["train_input"] = dataset["train_input"].float()
        dataset["test_input"] = dataset["test_input"].float()
        return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)

    def forward(self, observation):
        """Map a single NumPy observation to an action."""
        observation = torch.from_numpy(observation).float()
        action = self.policy(observation.unsqueeze(0))
        if self._action_is_discrete:
            return action.argmax(dim=-1).squeeze().item()
        else:
            return action.squeeze(0).detach().numpy()

    def train_from_policy(self, policy: Callable[[np.ndarray], Union[np.ndarray, int, float]], steps: int):
        raise NotImplementedError()  # TODO
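

if __name__ == "__main__":
    # Minimal usage sketch. The dataset path below is hypothetical: it is assumed to have been
    # written with torch.save and to contain the train_input/train_label/test_input/test_label
    # tensors expected by train_from_dataset.
    extractor = InterpretablePolicyExtractor("CartPole-v1")
    extractor.train_from_dataset("cartpole_dataset.pt", steps=20)

    observation, _ = extractor.env.reset()
    action = extractor.forward(observation)
    print("KAN policy action:", action)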