|
import torch |
|
import torch.nn as nn |
|
import os |
|
import numpy as np |
|
|
|
from fastshap import FastSHAP |
|
from opendataval.dataloader import DataFetcher |
|
|
|
# Shorthand used in this module's type hints: every "Array" is a NumPy ndarray.
Array = np.ndarray
|
|
|
def load_model(
    dataset_name: str,
    *,
    root: str = "fastshap",
    map_location=None,
) -> FastSHAP:
    """Rebuild a FastSHAP explainer from the objects saved for *dataset_name*.

    Reads ``<root>/<dataset_name>/explainer.pt`` and
    ``<root>/<dataset_name>/imputer.pt`` with ``torch.load`` and wraps them in a
    ``FastSHAP`` object whose link function is a softmax over the last dimension.

    Args:
        dataset_name: Sub-directory of *root* holding the two checkpoint files.
        root: Base directory of the checkpoints. Defaults to ``"fastshap"``,
            resolved relative to the current working directory.
        map_location: Forwarded to ``torch.load``. For example, pass ``"cpu"``
            to load GPU-saved checkpoints on a CPU-only machine.

    Returns:
        The assembled ``FastSHAP`` explainer.

    Note:
        The files are fully unpickled (``weights_only=False``), which can
        execute arbitrary code. Only load checkpoints from a trusted source.
    """
    model_dir = os.path.join(root, dataset_name)

    def _load(filename: str):
        # The checkpoints hold whole objects, not state dicts: they are handed
        # straight to FastSHAP below. torch>=2.6 defaults to weights_only=True,
        # which refuses to unpickle such objects, so opt out explicitly.
        # (The weights_only keyword requires torch>=1.13.)
        return torch.load(
            os.path.join(model_dir, filename),
            map_location=map_location,
            weights_only=False,
        )

    explainer = _load("explainer.pt")
    imputer = _load("imputer.pt")
    return FastSHAP(explainer, imputer, link=nn.Softmax(dim=-1))
|
|
|
def onehot_to_location(y: Array) -> Array:
    """Collapse one-hot rows into the column index of their first positive entry.

    Args:
        y: 2-D array-like of shape ``(n_samples, n_classes)``, normally a
            one-hot label matrix.

    Returns:
        1-D ``float64`` array of length ``n_samples``. Element ``i`` is the
        column index of the first strictly positive value in row ``i``.
        A row with no positive value maps to ``0.0``, which is
        indistinguishable from class 0. Only pass genuine one-hot encodings.

    Raises:
        ValueError: If *y* is not two-dimensional.
    """
    y = np.asarray(y)
    if y.ndim != 2:
        raise ValueError(f"expected a 2-D array, got shape {y.shape}")

    positive = y > 0
    if positive.shape[1] == 0:
        # argmax cannot reduce over an empty axis; with no columns every row
        # has no positive entry, so every row maps to 0.
        return np.zeros(positive.shape[0])

    # On a boolean mask, argmax returns the index of the first True
    # (or 0 when the row has none), matching a first-match-then-break scan.
    # Cast to float64 to keep the historical return dtype.
    return positive.argmax(axis=1).astype(np.float64)
|
|
|
def load_dataset(
    dataset_name: str, train_size: int, valid_size: int
) -> tuple[Array, Array, Array, Array]:
    """Fetch an OpenDataVal dataset and return its train/validation splits as arrays.

    Args:
        dataset_name: Name handed to ``DataFetcher``.
        train_size: Number of points requested for the training split.
        valid_size: Number of points requested for the validation split.

    Returns:
        ``(x_train, y_train, x_val, y_val)``. Each split is rebuilt with
        ``np.array(split.tolist())``. The label splits (presumably one-hot;
        TODO confirm for each dataset) are then collapsed to class indices by
        ``onehot_to_location``. The fetcher's test split is discarded.
    """

    def as_features(split):
        # Round-trip through Python lists to get a fresh NumPy array.
        return np.array(split.tolist())

    def as_labels(split):
        return onehot_to_location(as_features(split))

    splitter = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
        train_size, valid_size
    )
    feats_tr, labels_tr, feats_va, labels_va, _test_x, _test_y = splitter.datapoints

    # Tuple elements evaluate left to right, preserving the conversion order.
    return (
        as_features(feats_tr),
        as_labels(labels_tr),
        as_features(feats_va),
        as_labels(labels_va),
    )
|
|
|
def main(dataset_name: str = "electricity", train_size: int = 100, index: int = 24) -> None:
    """Print the summed FastSHAP attribution of one training point for its own label.

    Loads the saved explainer for *dataset_name*, draws *train_size* training
    points with an empty validation split, and prints the sum over all features
    of the SHAP values of sample *index* for that sample's label.

    Args:
        dataset_name: Dataset / checkpoint directory name.
        train_size: Number of training points to fetch.
        index: Position of the sample to explain within the training split.
    """
    explainer = load_model(dataset_name)
    x_train, y_train, *_ = load_dataset(dataset_name, train_size, 0)
    label = int(y_train[index])

    # For tabular input, FastSHAP.shap_values returns an array of shape
    # (batch, num_features, num_classes). Select the single sample, every
    # feature, and the sample's own class. The previous `[:, label]`
    # indexed the *feature* axis instead.
    values = explainer.shap_values(x_train[index : index + 1])
    print(np.sum(values[0, :, label]))


if __name__ == "__main__":
    main()