|
import torch |
|
from typing import Dict, Tuple, Optional, Callable, Union |
|
import gymnasium as gym |
|
from kan import KAN |
|
import numpy as np |
|
|
|
|
|
def extract_dim(space: gym.Space):
    """Return the flat size of a Gymnasium space and whether it is discrete.

    Args:
        space: The observation or action space to inspect.

    Returns:
        A ``(size, is_discrete)`` pair: ``(space.n, True)`` for a
        ``Discrete`` space, ``(space.shape[0], False)`` for a
        one-dimensional ``Box``.

    Raises:
        NotImplementedError: For any other space, including ``Box`` spaces
            whose shape is not one-dimensional.
    """
    if isinstance(space, gym.spaces.Discrete):
        return space.n, True

    # Only rank-1 boxes are handled; images and other multi-axis boxes fall
    # through to the error below.
    is_flat_box = isinstance(space, gym.spaces.Box) and len(space.shape) == 1
    if is_flat_box:
        return space.shape[0], False

    raise NotImplementedError(f"There is no support for space {space}.")
|
|
|
|
|
class InterpretablePolicyExtractor:
    """Fit a KAN network as a policy for a Gymnasium environment.

    The network maps an observation vector to either a continuous action
    vector (trained with MSE loss) or one score per discrete action
    (trained with cross-entropy loss; the arg-max score is the action).
    """

    # Names of candidate functions. Not referenced anywhere in this class;
    # presumably meant for a later symbolic-fitting step -- TODO confirm.
    lib = ['x', 'x^2', 'x^3', 'x^4', 'exp', 'log', 'sqrt', 'tanh', 'sin', 'abs']

    def __init__(self, env_name: str, hidden_widths: Optional[Tuple[int, ...]] = None):
        """Create the environment and a KAN sized to its spaces.

        Args:
            env_name: Environment id passed to ``gym.make``.
            hidden_widths: Widths of the hidden layers placed between the
                input and output layers. ``None`` means no hidden layers.

        Raises:
            NotImplementedError: If the observation or action space is not
                supported by ``extract_dim``.
        """
        self.env = gym.make(env_name)
        if hidden_widths is None:
            hidden_widths = []

        observation_dim, self._observation_is_discrete = extract_dim(self.env.observation_space)
        action_dim, self._action_is_discrete = extract_dim(self.env.action_space)

        self.policy = KAN(width=[observation_dim, *hidden_widths, action_dim])
        # Discrete actions are treated as classification, continuous ones
        # as regression.
        if self._action_is_discrete:
            self.loss_fn = torch.nn.CrossEntropyLoss()
        else:
            self.loss_fn = torch.nn.MSELoss()

    def train_from_dataset(self, dataset: Union[Dict[str, torch.Tensor], str], steps: int = 20):
        """Train the policy network on a supervised dataset.

        Args:
            dataset: Either a dict holding at least ``"train_label"`` and
                ``"test_label"`` tensors (plus whatever else
                ``self.policy.train`` expects), or a path to such a dict
                saved with ``torch.save``. A dict passed in is modified in
                place when its labels need reshaping.
            steps: Number of optimizer steps forwarded to the trainer.

        Returns:
            Whatever ``self.policy.train`` returns.
        """
        if isinstance(dataset, str):
            # NOTE(security): torch.load unpickles the file; only load
            # datasets that come from a trusted source.
            dataset = torch.load(dataset)

        if not self._action_is_discrete:
            # Regression targets need an explicit trailing axis. Each split
            # is checked on its own rank: the original code re-checked
            # "train_label" for the test split, so a 1-D "test_label" was
            # never reshaped.
            for key in ("train_label", "test_label"):
                if dataset[key].ndim == 1:
                    dataset[key] = dataset[key][:, None]

        return self.policy.train(dataset, opt="LBFGS", steps=steps, loss_fn=self.loss_fn)

    def forward(self, observation):
        """Compute the action for a single observation.

        Args:
            observation: A NumPy array holding one (unbatched) observation.

        Returns:
            For discrete action spaces, the index of the highest-scoring
            action as a Python ``int``; otherwise the network output with
            the batch axis removed, as a NumPy array.
        """
        observation = torch.from_numpy(observation)
        # The network is called on a batch, so add a leading axis of size 1.
        action = self.policy(observation.unsqueeze(0))
        if self._action_is_discrete:
            return action.argmax(dim=-1).squeeze().item()
        return action.squeeze(0).detach().numpy()

    def train_from_policy(self, policy: Callable[[np.ndarray], Union[np.ndarray, int, float]], steps: int):
        """Train by imitating ``policy``. Not implemented yet.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError()
|
|