Percy98 commited on
Commit
94f77f5
1 Parent(s): 0b32a90

add models and scripts

Browse files
adv/2dplanes/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:50550fb39b33f5b915793c3bb5ef65221cacf0231eb58383293b778626c18149
3
+ size 292683
adv/MiniBooNE/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4651ec0a6994d39c0615c3a914a3e1953f8bf916ddce5eaf34103f5ef8df846
3
+ size 333643
adv/electricity/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:24a51559d8b04472596dc323f0fa12b9d9bc6157f673fc0b2143032c8095159e
3
+ size 288587
adv/fried/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a7acd9b0fdbaefc284b3973e8d2e317a9838f038e22b753c0cdef7fb2128cf1e
3
+ size 292683
adv/nomao/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:114cb75c22580e8a731eafa11cf67fb721022748500e389763bd67d6de49e08e
3
+ size 373579
adv/pol/best_checkpoint.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:229d38de571687b49de6ce6567d280cbd40a2d469b657463d8cf7238d4ad7913
3
+ size 331595
fastshap/2dplanes/explainer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d972175f11b885ca3643b14afc2e42b00284a8164ac02157901e73b60d2ec3ee
3
+ size 86108
fastshap/2dplanes/imputer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c1a51cdc977bd70242b643051ec22e4e4e9b16ac31d44cba52cd4075a0df88da
3
+ size 82120
fastshap/MiniBooNE/explainer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f5b6f3116ce7073f36e35b9515d6bd5fb2feff38013264127e435a3eed542cac
3
+ size 147868
fastshap/MiniBooNE/imputer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:65b6789c8d86f1247ec8ecab4edc6db4e3225ab9d2ce472adff6984e746852b2
3
+ size 123272
fastshap/electricity/explainer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:730a7c6d948db8617b336958a26eb67cd78cc791faef92313324c4e6909c97b3
3
+ size 79900
fastshap/electricity/imputer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f3031d87212634d28af9fb43c1d7b71a2c9d8c51a9364d28d03c0b4accf38eb
3
+ size 78024
fastshap/fried/explainer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:372c92f480bcaf28595c17e7608a53b38eebc82310171af0a7be95a8dba28ea6
3
+ size 86108
fastshap/fried/imputer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1965b76e191feff4b2c24a42d0f351a0616de0e77aa45aa18098e7039ab3caf3
3
+ size 82056
fastshap/nomao/explainer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:43b7e5290a2973240082853ba9f195949e7be96cef0d784b5c432941c6bda1ae
3
+ size 208092
fastshap/nomao/imputer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9a2497dd4411343026cdab089ef238a27c988b36143cdfe1c9274de98e7279e5
3
+ size 162952
fastshap/pol/explainer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e602a98b32f2d4066b60b709444031a4ffd3cdc4bbee4281b4fbd8e810374280
3
+ size 144796
fastshap/pol/imputer.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:34bcd404134aab210b84849c6bace214b7f1577c3082b3ef7d7d3c046c8e7627
3
+ size 120968
use_adv.py ADDED
@@ -0,0 +1,56 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ import numpy as np
4
+
5
+ import lightning as pl
6
+ from torch.utils.data import TensorDataset, DataLoader
7
+ from adv import ValuationModel
8
+ from opendataval.dataloader import DataFetcher
9
+
10
+ Array = np.ndarray
11
+
12
+ def load_model(dataset_name: str) -> ValuationModel:
13
+ results = torch.load(
14
+ os.path.join("adv", dataset_name, "best_checkpoint.pt")
15
+ )
16
+ return results
17
+
18
+ def onehot_to_location(y: Array) -> Array:
19
+ res = np.zeros((y.shape[0]), dtype=int)
20
+ for i in range(y.shape[0]):
21
+ for j in range(len(y[i])):
22
+ if y[i, j] > 0:
23
+ res[i] = j
24
+ break
25
+ return res
26
+
27
+ def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
28
+ fetcher = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
29
+ train_size, valid_size
30
+ )
31
+ x_train, y_train, x_val, y_val, _, _ = fetcher.datapoints
32
+ x_train = np.array(x_train.tolist())
33
+ y_train = onehot_to_location(np.array(y_train.tolist()))
34
+ x_val = np.array(x_val.tolist())
35
+ y_val = onehot_to_location(np.array(y_val.tolist()))
36
+ return x_train, y_train, x_val, y_val
37
+
38
+ checkpoint = load_model("electricity")
39
+ x_train, y_train, *_ = load_dataset("electricity", 100, 0)
40
+ tensor_x = torch.Tensor([x_train[24]])
41
+ tensor_y = torch.Tensor([y_train[24]])
42
+ tensor_y = tensor_y.long()
43
+
44
+ dataset = TensorDataset(tensor_x, tensor_y)
45
+ dataloader = DataLoader(dataset, shuffle=False, drop_last=False, num_workers=0)
46
+ trainer = pl.Trainer(
47
+ precision="bf16-mixed", accelerator="gpu", devices=[0], enable_progress_bar=False
48
+ )
49
+ estimates = (
50
+ torch.cat(trainer.predict(checkpoint, dataloader))
51
+ .squeeze()
52
+ .cpu()
53
+ .float()
54
+ )
55
+
56
+ print(estimates.numpy())
use_fastshap.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import os
4
+ import numpy as np
5
+
6
+ from fastshap import FastSHAP
7
+ from opendataval.dataloader import DataFetcher
8
+
9
+ Array = np.ndarray
10
+
11
+ def load_model(dataset_name: str) -> FastSHAP:
12
+ explainer = torch.load(os.path.join("fastshap", dataset_name, "explainer.pt"))
13
+ imputer = torch.load(os.path.join("fastshap", dataset_name, "imputer.pt"))
14
+ shap = FastSHAP(explainer, imputer, link=nn.Softmax(dim=-1))
15
+ return shap
16
+
17
+ def onehot_to_location(y: Array) -> Array:
18
+ res = np.zeros((y.shape[0]))
19
+ for i in range(y.shape[0]):
20
+ for j in range(len(y[i])):
21
+ if y[i, j] > 0:
22
+ res[i] = j
23
+ break
24
+ return res
25
+
26
+ def load_dataset(dataset_name: str, train_size: int, valid_size: int) -> tuple[Array, Array, Array, Array]:
27
+ fetcher = DataFetcher(dataset_name=dataset_name).split_dataset_by_count(
28
+ train_size, valid_size
29
+ )
30
+ x_train, y_train, x_val, y_val, _, _ = fetcher.datapoints
31
+ x_train = np.array(x_train.tolist())
32
+ y_train = onehot_to_location(np.array(y_train.tolist()))
33
+ x_val = np.array(x_val.tolist())
34
+ y_val = onehot_to_location(np.array(y_val.tolist()))
35
+ return x_train, y_train, x_val, y_val
36
+
37
+ fastshap = load_model("electricity")
38
+ x_train, y_train, *_ = load_dataset("electricity", 100, 0)
39
+ print(np.sum(fastshap.shap_values(np.array([x_train[24]]))[:, int(y_train[24])]))