add models and scripts
Browse files- adv/2dplanes/best_checkpoint.pt +3 -0
- adv/MiniBooNE/best_checkpoint.pt +3 -0
- adv/electricity/best_checkpoint.pt +3 -0
- adv/fried/best_checkpoint.pt +3 -0
- adv/nomao/best_checkpoint.pt +3 -0
- adv/pol/best_checkpoint.pt +3 -0
- fastshap/2dplanes/explainer.pt +3 -0
- fastshap/2dplanes/imputer.pt +3 -0
- fastshap/MiniBooNE/explainer.pt +3 -0
- fastshap/MiniBooNE/imputer.pt +3 -0
- fastshap/electricity/explainer.pt +3 -0
- fastshap/electricity/imputer.pt +3 -0
- fastshap/fried/explainer.pt +3 -0
- fastshap/fried/imputer.pt +3 -0
- fastshap/nomao/explainer.pt +3 -0
- fastshap/nomao/imputer.pt +3 -0
- fastshap/pol/explainer.pt +3 -0
- fastshap/pol/imputer.pt +3 -0
- use_adv.py +56 -0
- use_fastshap.py +39 -0
adv/2dplanes/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:50550fb39b33f5b915793c3bb5ef65221cacf0231eb58383293b778626c18149
|
3 |
+
size 292683
|
adv/MiniBooNE/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d4651ec0a6994d39c0615c3a914a3e1953f8bf916ddce5eaf34103f5ef8df846
|
3 |
+
size 333643
|
adv/electricity/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:24a51559d8b04472596dc323f0fa12b9d9bc6157f673fc0b2143032c8095159e
|
3 |
+
size 288587
|
adv/fried/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a7acd9b0fdbaefc284b3973e8d2e317a9838f038e22b753c0cdef7fb2128cf1e
|
3 |
+
size 292683
|
adv/nomao/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:114cb75c22580e8a731eafa11cf67fb721022748500e389763bd67d6de49e08e
|
3 |
+
size 373579
|
adv/pol/best_checkpoint.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:229d38de571687b49de6ce6567d280cbd40a2d469b657463d8cf7238d4ad7913
|
3 |
+
size 331595
|
fastshap/2dplanes/explainer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d972175f11b885ca3643b14afc2e42b00284a8164ac02157901e73b60d2ec3ee
|
3 |
+
size 86108
|
fastshap/2dplanes/imputer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c1a51cdc977bd70242b643051ec22e4e4e9b16ac31d44cba52cd4075a0df88da
|
3 |
+
size 82120
|
fastshap/MiniBooNE/explainer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f5b6f3116ce7073f36e35b9515d6bd5fb2feff38013264127e435a3eed542cac
|
3 |
+
size 147868
|
fastshap/MiniBooNE/imputer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:65b6789c8d86f1247ec8ecab4edc6db4e3225ab9d2ce472adff6984e746852b2
|
3 |
+
size 123272
|
fastshap/electricity/explainer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:730a7c6d948db8617b336958a26eb67cd78cc791faef92313324c4e6909c97b3
|
3 |
+
size 79900
|
fastshap/electricity/imputer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9f3031d87212634d28af9fb43c1d7b71a2c9d8c51a9364d28d03c0b4accf38eb
|
3 |
+
size 78024
|
fastshap/fried/explainer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:372c92f480bcaf28595c17e7608a53b38eebc82310171af0a7be95a8dba28ea6
|
3 |
+
size 86108
|
fastshap/fried/imputer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:1965b76e191feff4b2c24a42d0f351a0616de0e77aa45aa18098e7039ab3caf3
|
3 |
+
size 82056
|
fastshap/nomao/explainer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:43b7e5290a2973240082853ba9f195949e7be96cef0d784b5c432941c6bda1ae
|
3 |
+
size 208092
|
fastshap/nomao/imputer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:9a2497dd4411343026cdab089ef238a27c988b36143cdfe1c9274de98e7279e5
|
3 |
+
size 162952
|
fastshap/pol/explainer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:e602a98b32f2d4066b60b709444031a4ffd3cdc4bbee4281b4fbd8e810374280
|
3 |
+
size 144796
|
fastshap/pol/imputer.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:34bcd404134aab210b84849c6bace214b7f1577c3082b3ef7d7d3c046c8e7627
|
3 |
+
size 120968
|
use_adv.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import os
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import lightning as pl
|
6 |
+
from torch.utils.data import TensorDataset, DataLoader
|
7 |
+
from adv import ValuationModel
|
8 |
+
from opendataval.dataloader import DataFetcher
|
9 |
+
|
10 |
+
Array = np.ndarray
|
11 |
+
|
12 |
+
def load_model(dataset_name: str) -> ValuationModel:
|
13 |
+
results = torch.load(
|
14 |
+
os.path.join("adv", dataset_name, "best_checkpoint.pt")
|
15 |
+
)
|
16 |
+
return results
|
17 |
+
|
18 |
+
def onehot_to_location(y: Array) -> Array:
|
19 |
+
res = np.zeros((y.shape[0]), dtype=int)
|
20 |
+
for i in range(y.shape[0]):
|
21 |
+
for j in range(len(y[i])):
|
22 |
+
if y[i, j] > 0:
|
23 |
+
res[i] = j
|
24 |
+
break
|
25 |
+
return res
|
26 |
+
|
27 |
+
def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
|
28 |
+
fetcher = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
|
29 |
+
train_size, valid_size
|
30 |
+
)
|
31 |
+
x_train, y_train, x_val, y_val, _, _ = fetcher.datapoints
|
32 |
+
x_train = np.array(x_train.tolist())
|
33 |
+
y_train = onehot_to_location(np.array(y_train.tolist()))
|
34 |
+
x_val = np.array(x_val.tolist())
|
35 |
+
y_val = onehot_to_location(np.array(y_val.tolist()))
|
36 |
+
return x_train, y_train, x_val, y_val
|
37 |
+
|
38 |
+
checkpoint = load_model("electricity")
|
39 |
+
x_train, y_train, *_ = load_dataset("electricity", 100, 0)
|
40 |
+
tensor_x = torch.Tensor([x_train[24]])
|
41 |
+
tensor_y = torch.Tensor([y_train[24]])
|
42 |
+
tensor_y = tensor_y.long()
|
43 |
+
|
44 |
+
dataset = TensorDataset(tensor_x, tensor_y)
|
45 |
+
dataloader = DataLoader(dataset, shuffle=False, drop_last=False, num_workers=0)
|
46 |
+
trainer = pl.Trainer(
|
47 |
+
precision="bf16-mixed", accelerator="gpu", devices=[0], enable_progress_bar=False
|
48 |
+
)
|
49 |
+
estimates = (
|
50 |
+
torch.cat(trainer.predict(checkpoint, dataloader))
|
51 |
+
.squeeze()
|
52 |
+
.cpu()
|
53 |
+
.float()
|
54 |
+
)
|
55 |
+
|
56 |
+
print(estimates.numpy())
|
use_fastshap.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
from fastshap import FastSHAP
|
7 |
+
from opendataval.dataloader import DataFetcher
|
8 |
+
|
9 |
+
Array = np.ndarray
|
10 |
+
|
11 |
+
def load_model(dataset_name: str) -> FastSHAP:
|
12 |
+
explainer = torch.load(os.path.join("fastshap", dataset_name, "explainer.pt"))
|
13 |
+
imputer = torch.load(os.path.join("fastshap", dataset_name, "imputer.pt"))
|
14 |
+
shap = FastSHAP(explainer, imputer, link=nn.Softmax(dim=-1))
|
15 |
+
return shap
|
16 |
+
|
17 |
+
def onehot_to_location(y: Array) -> Array:
|
18 |
+
res = np.zeros((y.shape[0]))
|
19 |
+
for i in range(y.shape[0]):
|
20 |
+
for j in range(len(y[i])):
|
21 |
+
if y[i, j] > 0:
|
22 |
+
res[i] = j
|
23 |
+
break
|
24 |
+
return res
|
25 |
+
|
26 |
+
def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
|
27 |
+
fetcher = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
|
28 |
+
train_size, valid_size
|
29 |
+
)
|
30 |
+
x_train, y_train, x_val, y_val, _, _ = fetcher.datapoints
|
31 |
+
x_train = np.array(x_train.tolist())
|
32 |
+
y_train = onehot_to_location(np.array(y_train.tolist()))
|
33 |
+
x_val = np.array(x_val.tolist())
|
34 |
+
y_val = onehot_to_location(np.array(y_val.tolist()))
|
35 |
+
return x_train, y_train, x_val, y_val
|
36 |
+
|
37 |
+
fastshap = load_model("electricity")
|
38 |
+
x_train, y_train, *_ = load_dataset("electricity", 100, 0)
|
39 |
+
print(np.sum(fastshap.shap_values(np.array([x_train[24]]))[:, int(y_train[24])]))
|