File size: 3,309 Bytes
6fc43ab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
__all__ = ['MCExplainer']

from . import BaseExplainer
from typing import Any, Type
from torch import set_grad_enabled
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import random
import torch
import numpy as np
from tqdm import tqdm
Tensor = Type[torch.Tensor]

NUM_PERMUTATIONS = 1024
BATCH_SIZE = NUM_PERMUTATIONS

class MCExplainer(BaseExplainer):

    def __init__(self,
        model: Any,
    ):
        """ ... """
        super().__init__(model)

    def _shap_values_core(self,
        smp: dict[str, Tensor],
        mask: dict[str, Tensor],
        phi_: dict[str, dict[str, float]],
        is_embedding: dict[str, bool] | None = None,
    ):
        """ ... """
        # get the list of available feature names
        avail = [k for k in mask if mask[k].item() == False]

        # repeat feature dict and mount to device
        smps = dict()
        for k, v in smp.items():
            if len(v.shape) == 1:
                smps[k] = smp[k].expand(NUM_PERMUTATIONS)
            elif len(v.shape) == 2:
                smps[k] = smp[k].expand(NUM_PERMUTATIONS, -1)
            elif len(v.shape) == 3:
                smps[k] = smp[k].expand(NUM_PERMUTATIONS, -1, -1)
            else:
                raise ValueError
        smps = {k: smps[k].to(self.model.device) for k in self.model.src_modalities}
        
        # loop through available features
        print('{} features to evaluate ...'.format(len(avail)))
        for src_k in tqdm(avail):
            # get features to uncover
            to_uncover = []
            for _ in range(NUM_PERMUTATIONS):
                perm = avail.copy()
                random.shuffle(perm)
                to_uncover.append(perm[:perm.index(src_k)])

            # construct masks without src_k
            masks_wo_src_k = {k: np.ones(NUM_PERMUTATIONS, dtype=np.bool_) for k in self.model.src_modalities}
            for i, lst in enumerate(to_uncover):
                for k in lst:
                    masks_wo_src_k[k][i] = False

            # construct masks with src_k
            masks_wi_src_k = masks_wo_src_k.copy()
            masks_wi_src_k[src_k] = np.zeros(NUM_PERMUTATIONS, dtype=np.bool_)

            # mount inputs to device
            masks_wi_src_k = {k: torch.tensor(masks_wi_src_k[k], device=self.model.device) for k in self.model.src_modalities}
            masks_wo_src_k = {k: torch.tensor(masks_wo_src_k[k], device=self.model.device) for k in self.model.src_modalities}

            # run model
            out_wi_src_k = self.model.net_(smps, masks_wi_src_k, is_embedding)
            out_wo_src_k = self.model.net_(smps, masks_wo_src_k, is_embedding)

            # to numpy
            out_wi_src_k = {k: out_wi_src_k[k].cpu().numpy() for k in self.model.tgt_modalities}
            out_wo_src_k = {k: out_wo_src_k[k].cpu().numpy() for k in self.model.tgt_modalities}

            # replace nan with zeros when all input features are excluded
            out_wo_src_k = {k: np.nan_to_num(out_wo_src_k[k]) for k in self.model.tgt_modalities}

            # calculate shap values
            mean = {k: (out_wi_src_k[k] - out_wo_src_k[k]).mean() for k in self.model.tgt_modalities}   
            for tgt_k in self.model.tgt_modalities:
                phi_[tgt_k][src_k] = mean[tgt_k]