# ShapleyValuePrediction / use_fastshap.py
# Author: Percy98
# Commit 94f77f5 (verified): add models and scripts
import torch
import torch.nn as nn
import os
import numpy as np
from fastshap import FastSHAP
from opendataval.dataloader import DataFetcher
# Shorthand for NumPy arrays, used in the type annotations of this module.
Array = np.ndarray
def load_model(dataset_name: str) -> FastSHAP:
    """Build a FastSHAP wrapper from the saved models for ``dataset_name``.

    Reads ``explainer.pt`` and ``imputer.pt`` from ``fastshap/<dataset_name>/``
    (relative to the current working directory) and combines them with a
    softmax link applied over the last dimension.

    NOTE(review): ``torch.load`` unpickles the files, so only load checkpoints
    that come from a trusted source.
    NOTE(review): on PyTorch releases where ``torch.load`` defaults to
    ``weights_only=True`` a pickled module object may fail to load -- confirm
    the checkpoint format against the installed PyTorch version.
    """
    model_dir = os.path.join("fastshap", dataset_name)
    # The generator is consumed left to right: explainer first, then imputer.
    explainer, imputer = (
        torch.load(os.path.join(model_dir, filename))
        for filename in ("explainer.pt", "imputer.pt")
    )
    return FastSHAP(explainer, imputer, link=nn.Softmax(dim=-1))
def onehot_to_location(y: Array) -> Array:
    """Return, for each row of ``y``, the index of its first positive entry.

    Intended for turning one-hot encoded labels into class indices.

    Args:
        y: 2-D array with one row per sample. An array with no rows (of any
            dimensionality) is also accepted.

    Returns:
        A float64 array of length ``y.shape[0]``. A row without any positive
        entry maps to ``0.0``.

    Raises:
        ValueError: If ``y`` has at least one row but is not 2-D.
    """
    n_rows = y.shape[0]
    # No rows, or rows without columns: there is nothing to search, and
    # ``argmax`` would raise on an empty axis.
    if n_rows == 0 or (y.ndim == 2 and y.shape[1] == 0):
        return np.zeros(n_rows)
    if y.ndim != 2:
        raise ValueError(
            f"expected a 2-D one-hot array, got an array with shape {y.shape}"
        )
    # ``argmax`` over a boolean mask yields the first True position per row,
    # and 0 when the row has no True at all.
    return np.argmax(y > 0, axis=1).astype(np.float64)
def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
    """Fetch a dataset and return its train/validation splits as NumPy arrays.

    The dataset is obtained through ``DataFetcher`` and split by count into
    ``train_size`` training and ``valid_size`` validation points. Features are
    converted to arrays as-is; labels are additionally passed through
    ``onehot_to_location`` so each label becomes a single index.

    Returns:
        ``(x_train, y_train, x_val, y_val)``.
    """

    def as_array(values):
        # Round-trip through a Python list to obtain a plain NumPy array.
        return np.array(values.tolist())

    fetcher = DataFetcher(dataset_name=dataset_name)
    fetcher = fetcher.split_dataset_by_count(train_size, valid_size)

    # ``datapoints`` unpacks into six items; the last two (presumably the
    # test split -- TODO confirm) are not used here.
    train_x, train_y, val_x, val_y, _, _ = fetcher.datapoints

    return (
        as_array(train_x),
        onehot_to_location(as_array(train_y)),
        as_array(val_x),
        onehot_to_location(as_array(val_y)),
    )
# Example run: load the saved "electricity" explainer, fetch 100 training
# points (no validation points) and print a summed attribution for one sample.
fastshap = load_model("electricity")
x_train, y_train, *_ = load_dataset("electricity", 100, 0)

sample_index = 24
# shap_values is given a batch, so wrap the single sample in a length-1 array.
sample_batch = np.array([x_train[sample_index]])
# Labels come back as floats; cast to int before using one as an index.
target_class = int(y_train[sample_index])

attributions = fastshap.shap_values(sample_batch)
# NOTE(review): this indexes the second axis with the class label. If
# shap_values returns (batch, features, classes), that selects a feature
# rather than a class -- confirm the intended axis.
print(np.sum(attributions[:, target_class]))