|
__all__ = ['MCExplainer'] |
|
|
|
from . import BaseExplainer |
|
from typing import Any, Type |
|
from torch import set_grad_enabled |
|
from torch.utils.data import Dataset |
|
from torch.utils.data import DataLoader |
|
import random |
|
import torch |
|
import numpy as np |
|
from tqdm import tqdm |
|
Tensor = torch.Tensor
|
|
|
# number of random feature orderings sampled per feature; all orderings for a
# feature are scored in a single batched forward pass
NUM_PERMUTATIONS = 1024
|
BATCH_SIZE = NUM_PERMUTATIONS |
|
|
|
class MCExplainer(BaseExplainer): |
|
|
|
def __init__(self, |
|
model: Any, |
|
): |
|
""" ... """ |
|
super().__init__(model) |
|
|
|
def _shap_values_core(self, |
|
smp: dict[str, Tensor], |
|
mask: dict[str, Tensor], |
|
phi_: dict[str, dict[str, float]], |
|
is_embedding: dict[str, bool] | None = None, |
|
): |
|
""" ... """ |
|
|
|
        # source features that are actually present (unmasked) in this sample
        avail = [k for k in mask if not mask[k].item()]
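
        # broadcast the single sample to NUM_PERMUTATIONS identical rows so that
        # every sampled feature ordering can be scored in one batched forward pass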
|
|
|
|
|
smps = dict() |
|
for k, v in smp.items(): |
|
if len(v.shape) == 1: |
|
smps[k] = smp[k].expand(NUM_PERMUTATIONS) |
|
elif len(v.shape) == 2: |
|
smps[k] = smp[k].expand(NUM_PERMUTATIONS, -1) |
|
elif len(v.shape) == 3: |
|
smps[k] = smp[k].expand(NUM_PERMUTATIONS, -1, -1) |
|
else: |
|
                raise ValueError(f'unsupported tensor shape {tuple(v.shape)} for modality "{k}"')
|
smps = {k: smps[k].to(self.model.device) for k in self.model.src_modalities} |
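
        # iterate over the available features; each feature's Shapley value is
        # estimated from NUM_PERMUTATIONS sampled orderings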
|
|
|
|
|
print('{} features to evaluate ...'.format(len(avail))) |
|
for src_k in tqdm(avail): |
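            # sample random orderings of the available features; for each ordering,
            # record the features that appear before src_k (these get uncovered)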
|
|
|
to_uncover = [] |
|
for _ in range(NUM_PERMUTATIONS): |
|
perm = avail.copy() |
|
random.shuffle(perm) |
|
to_uncover.append(perm[:perm.index(src_k)]) |
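
            # build boolean masks (True = hidden) that reveal only the features
            # preceding src_k in each sampled ordering, i.e. coalitions without src_k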
|
|
|
|
|
masks_wo_src_k = {k: np.ones(NUM_PERMUTATIONS, dtype=np.bool_) for k in self.model.src_modalities} |
|
for i, lst in enumerate(to_uncover): |
|
for k in lst: |
|
masks_wo_src_k[k][i] = False |
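
            # the same coalitions, but with src_k additionally revealed; the shallow
            # dict copy is safe because the entry for src_k is replaced, not mutated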
|
|
|
|
|
masks_wi_src_k = masks_wo_src_k.copy() |
|
masks_wi_src_k[src_k] = np.zeros(NUM_PERMUTATIONS, dtype=np.bool_) |
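
            # convert the boolean masks to tensors on the model's device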
|
|
|
|
|
masks_wi_src_k = {k: torch.tensor(masks_wi_src_k[k], device=self.model.device) for k in self.model.src_modalities} |
|
masks_wo_src_k = {k: torch.tensor(masks_wo_src_k[k], device=self.model.device) for k in self.model.src_modalities} |
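
            # two batched forward passes: model outputs for the coalitions
            # with and without src_k revealed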
|
|
|
|
|
out_wi_src_k = self.model.net_(smps, masks_wi_src_k, is_embedding) |
|
out_wo_src_k = self.model.net_(smps, masks_wo_src_k, is_embedding) |
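
            # pull per-target predictions back to CPU as numpy arrays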
|
|
|
|
|
out_wi_src_k = {k: out_wi_src_k[k].cpu().numpy() for k in self.model.tgt_modalities} |
|
out_wo_src_k = {k: out_wo_src_k[k].cpu().numpy() for k in self.model.tgt_modalities} |
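
            # replace NaNs in the "without src_k" outputs, e.g. when a sampled
            # coalition reveals no features at all (src_k drawn first); the
            # "with src_k" outputs always see at least one feature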
|
|
|
|
|
out_wo_src_k = {k: np.nan_to_num(out_wo_src_k[k]) for k in self.model.tgt_modalities} |
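
            # Shapley estimate: average the marginal contribution of src_k over
            # the sampled orderings, separately for each target modality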
|
|
|
|
|
mean = {k: (out_wi_src_k[k] - out_wo_src_k[k]).mean() for k in self.model.tgt_modalities} |
|
for tgt_k in self.model.tgt_modalities: |
|
            phi_[tgt_k][src_k] = float(mean[tgt_k])