xf3227's picture
ok
6fc43ab
raw
history blame
3.31 kB
__all__ = ['MCExplainer']
from . import BaseExplainer
from typing import Any, Type
from torch import set_grad_enabled
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import random
import torch
import numpy as np
from tqdm import tqdm
# Annotation alias for tensor-valued arguments. It points at the tensor class
# itself: ``Type[torch.Tensor]`` would describe the *class object*
# ``torch.Tensor`` rather than a tensor instance, which mis-types every
# ``dict[str, Tensor]`` annotation and cannot be used with ``isinstance``.
Tensor = torch.Tensor

# Number of random feature orderings sampled per feature when estimating
# SHAP values; also the leading size every sample tensor is expanded to.
NUM_PERMUTATIONS = 1024

# Kept equal to the number of permutations.
BATCH_SIZE = NUM_PERMUTATIONS
class MCExplainer(BaseExplainer):
    """Monte Carlo (permutation sampling) estimator of SHAP values.

    For every available input feature, ``NUM_PERMUTATIONS`` random orderings
    of the available features are drawn. In each ordering the features that
    precede the feature of interest are uncovered, and the model output is
    compared with and without the feature of interest. The mean difference
    is stored as that feature's SHAP value.
    """

    def __init__(
        self,
        model: Any,
    ) -> None:
        """Initialize the explainer.

        Args:
            model: Forwarded unchanged to ``BaseExplainer.__init__``.
                ``_shap_values_core`` reads ``self.model.device``,
                ``self.model.src_modalities``, ``self.model.tgt_modalities``
                and ``self.model.net_`` (presumably ``self.model`` is set by
                the base class -- confirm against ``BaseExplainer``).
        """
        super().__init__(model)

    def _shap_values_core(
        self,
        smp: dict[str, Tensor],
        mask: dict[str, Tensor],
        phi_: dict[str, dict[str, float]],
        is_embedding: dict[str, bool] | None = None,
    ) -> None:
        """Estimate SHAP values for one sample and write them into ``phi_``.

        Args:
            smp: Feature name -> tensor with 1, 2 or 3 dimensions. Each
                tensor is expanded to a leading size of ``NUM_PERMUTATIONS``
                (NOTE(review): ``expand`` only broadcasts singleton
                dimensions, so the leading dimension is presumably 1 --
                confirm against the caller). Must contain every name in
                ``self.model.src_modalities``.
            mask: Feature name -> single-element tensor. A falsy value marks
                the feature as available; masks use ``True`` for "excluded".
            phi_: Output container, updated in place:
                ``phi_[tgt][src]`` receives the estimate for target ``tgt``
                and source feature ``src``. ``phi_[tgt]`` must already exist
                for every name in ``self.model.tgt_modalities``.
            is_embedding: Passed through unchanged as the third positional
                argument of ``self.model.net_``.

        Raises:
            ValueError: If a tensor in ``smp`` does not have 1, 2 or 3
                dimensions.
        """
        # hoist loop-invariant attribute lookups
        device = self.model.device
        src_modalities = self.model.src_modalities
        tgt_modalities = self.model.tgt_modalities
        net = self.model.net_

        # get the list of available feature names (mask entry is False)
        avail = [k for k in mask if not mask[k].item()]

        # repeat every feature along a new leading "permutation" axis
        repeated = {}
        for k, v in smp.items():
            ndim = len(v.shape)
            if not 1 <= ndim <= 3:
                raise ValueError(
                    f'Unsupported tensor rank {ndim} for feature {k!r}; '
                    'expected 1, 2 or 3 dimensions.'
                )
            # -1 keeps the size of every trailing dimension unchanged
            repeated[k] = v.expand(NUM_PERMUTATIONS, *([-1] * (ndim - 1)))

        # keep only the model's source modalities and mount them to device
        smps = {k: repeated[k].to(device) for k in src_modalities}

        # loop through available features
        print(f'{len(avail)} features to evaluate ...')
        for src_k in tqdm(avail):
            # For each random ordering, the features placed before src_k are
            # the ones uncovered alongside it (src_k itself is never included).
            to_uncover = []
            for _ in range(NUM_PERMUTATIONS):
                perm = avail.copy()
                random.shuffle(perm)
                to_uncover.append(perm[:perm.index(src_k)])

            # masks without src_k: everything excluded except uncovered ones
            masks_wo_src_k = {
                k: np.ones(NUM_PERMUTATIONS, dtype=np.bool_)
                for k in src_modalities
            }
            for i, uncovered in enumerate(to_uncover):
                for k in uncovered:
                    masks_wo_src_k[k][i] = False

            # masks with src_k: same coalitions, but src_k is always uncovered.
            # A shallow copy is enough because the src_k entry is replaced by
            # a new array rather than modified in place.
            masks_wi_src_k = dict(masks_wo_src_k)
            masks_wi_src_k[src_k] = np.zeros(NUM_PERMUTATIONS, dtype=np.bool_)

            # mount masks to device
            masks_wi_src_k = {
                k: torch.tensor(masks_wi_src_k[k], device=device)
                for k in src_modalities
            }
            masks_wo_src_k = {
                k: torch.tensor(masks_wo_src_k[k], device=device)
                for k in src_modalities
            }

            # Run the model without autograd: gradients are never used here,
            # and ``.numpy()`` below raises on tensors that require grad.
            with torch.no_grad():
                out_wi_src_k = net(smps, masks_wi_src_k, is_embedding)
                out_wo_src_k = net(smps, masks_wo_src_k, is_embedding)

            # calculate shap values
            for tgt_k in tgt_modalities:
                with_src = out_wi_src_k[tgt_k].cpu().numpy()
                # replace nan with zeros when all input features are excluded
                without_src = np.nan_to_num(out_wo_src_k[tgt_k].cpu().numpy())
                phi_[tgt_k][src_k] = (with_src - without_src).mean()