File size: 1,390 Bytes
94f77f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.nn as nn
import os
import numpy as np

from fastshap import FastSHAP
from opendataval.dataloader import DataFetcher

Array = np.ndarray

def load_model(dataset_name: str) -> FastSHAP:
    """Build a FastSHAP object from the serialized models of one dataset.

    Args:
        dataset_name: Name of the sub-directory under ``fastshap/`` that
            holds ``explainer.pt`` and ``imputer.pt``.

    Returns:
        A ``FastSHAP`` instance wrapping the loaded explainer and imputer,
        with a softmax over the last dimension as its link function.
    """
    # Both files live side by side in fastshap/<dataset_name>/.
    model_dir = os.path.join("fastshap", dataset_name)
    # NOTE(review): torch.load unpickles its input; only point this at
    # trusted files.
    explainer = torch.load(os.path.join(model_dir, "explainer.pt"))
    imputer = torch.load(os.path.join(model_dir, "imputer.pt"))
    return FastSHAP(explainer, imputer, link=nn.Softmax(dim=-1))

def onehot_to_location(y: Array) -> Array:
    """Convert one-hot rows into the index of their first positive entry.

    Args:
        y: 2-D array of shape ``(n_samples, n_classes)``. An empty array of
            any shape (for example ``np.array([])``) is also accepted.

    Returns:
        A float64 array of length ``y.shape[0]``. Entry ``i`` is the column
        index of the first value in row ``i`` that is greater than zero, or
        ``0.0`` when the row has no positive value.
    """
    y = np.asarray(y)
    # An empty input has no usable class axis to reduce over (argmax would
    # raise), so return the all-zero result directly.
    if y.size == 0:
        return np.zeros(y.shape[0])
    # argmax over a boolean mask yields the first True in each row and 0 when
    # a row has no True at all. The cast keeps the float64 result dtype.
    return np.argmax(y > 0, axis=1).astype(np.float64)

def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
    """Fetch a dataset split and return it as NumPy arrays.

    Args:
        dataset_name: Name passed to ``DataFetcher``.
        train_size: Count of training points requested from the split.
        valid_size: Count of validation points requested from the split.

    Returns:
        ``(x_train, y_train, x_val, y_val)``. The feature arrays are the
        fetched values converted via ``tolist()``; the label arrays are the
        result of ``onehot_to_location`` applied to the converted labels.
    """

    def _as_array(data) -> Array:
        # Round-trip through a plain list to obtain an ndarray.
        return np.array(data.tolist())

    fetcher = DataFetcher(dataset_name=dataset_name)
    fetcher = fetcher.split_dataset_by_count(train_size, valid_size)
    # The last two datapoints are not used here.
    train_x, train_y, val_x, val_y, _, _ = fetcher.datapoints
    return (
        _as_array(train_x),
        onehot_to_location(_as_array(train_y)),
        _as_array(val_x),
        onehot_to_location(_as_array(val_y)),
    )

# Demo run: print the summed SHAP values of one "electricity" training point
# for that point's own label column.
fastshap = load_model("electricity")
x_train, y_train, *_ = load_dataset("electricity", 100, 0)

sample_index = 24
# shap_values is given a batch containing the single selected point.
sample_batch = np.array([x_train[sample_index]])
label_column = int(y_train[sample_index])
sample_shap = fastshap.shap_values(sample_batch)
print(np.sum(sample_shap[:, label_column]))