import os

import numpy as np
import torch
import torch.nn as nn
from fastshap import FastSHAP
from opendataval.dataloader import DataFetcher

Array = np.ndarray


def load_model(dataset_name: str) -> FastSHAP:
    """Build a FastSHAP object from the saved explainer and imputer.

    Both files are read from ``fastshap/<dataset_name>/`` relative to the
    current working directory (``explainer.pt`` and ``imputer.pt``).

    Args:
        dataset_name: Name of the sub-directory holding the saved models.

    Returns:
        A ``FastSHAP`` instance wrapping the loaded explainer and imputer,
        using a softmax over the last dimension as the link function.
    """
    model_dir = os.path.join("fastshap", dataset_name)
    # NOTE(review): torch.load unpickles the file contents; only point this at
    # model files from a trusted source.
    explainer = torch.load(os.path.join(model_dir, "explainer.pt"))
    imputer = torch.load(os.path.join(model_dir, "imputer.pt"))
    return FastSHAP(explainer, imputer, link=nn.Softmax(dim=-1))


def onehot_to_location(y: Array) -> Array:
    """Return, for each row of ``y``, the index of its first positive entry.

    Rows without any strictly positive entry map to ``0.0``.

    Args:
        y: Array indexed as ``y[row, column]``; an empty array is accepted.

    Returns:
        A float array of length ``y.shape[0]`` holding the column indices.
    """
    # An empty input (e.g. a zero-sized split) may have lost its second axis,
    # so it is answered directly instead of reducing over axis 1.
    if y.size == 0:
        return np.zeros(y.shape[0])
    # argmax over a boolean mask yields the first True column, and 0 when the
    # row contains no True at all.
    return np.argmax(y > 0, axis=1).astype(float)


def load_dataset(
    dataset_name: str, train_size: int, valid_size: int
) -> tuple[Array, Array, Array, Array]:
    """Fetch a dataset and return its train/validation splits as NumPy arrays.

    Args:
        dataset_name: Dataset name passed to ``DataFetcher``.
        train_size: Count passed as the first argument of
            ``split_dataset_by_count``.
        valid_size: Count passed as the second argument of
            ``split_dataset_by_count``.

    Returns:
        ``(x_train, y_train, x_val, y_val)``. The label arrays are converted
        with ``onehot_to_location``.
    """
    fetcher = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
        train_size, valid_size
    )
    # The last two entries of ``datapoints`` are discarded
    # (presumably the test split -- confirm against opendataval).
    x_train, y_train, x_val, y_val, _, _ = fetcher.datapoints

    # Round-trip through ``tolist()`` so the fetched containers become plain
    # NumPy arrays.
    x_train = np.array(x_train.tolist())
    y_train = onehot_to_location(np.array(y_train.tolist()))
    x_val = np.array(x_val.tolist())
    y_val = onehot_to_location(np.array(y_val.tolist()))
    return x_train, y_train, x_val, y_val


# Script section: explain a single training point of the "electricity" dataset.
fastshap = load_model("electricity")
x_train, y_train, *_ = load_dataset("electricity", 100, 0)

# Sum the SHAP values of sample 24, selecting index ``int(y_train[24])`` on the
# second axis (presumably the class axis -- confirm against FastSHAP).
print(np.sum(fastshap.shap_values(np.array([x_train[24]]))[:, int(y_train[24])]))