# nmed2024 / adrd / shap / base.py
# (file-viewer residue: uploader xf3227, commit "ok" 6fc43ab, size 2 kB; links: raw, history, blame)
from abc import ABC
from abc import abstractmethod
from typing import Any, Type
from functools import wraps
from torch.utils.data import DataLoader
from torch import set_grad_enabled
import torch
# Alias used in annotations for tensor *instances*. `Type[torch.Tensor]` would
# describe the class object itself rather than a tensor value, and cannot be
# used with isinstance().
Tensor = torch.Tensor
from ..utils.misc import convert_args_kwargs_to_kwargs
from ..utils import TransformerTestingDataset
from ..model import ADRDModel
class BaseExplainer:
    """Base class for SHAP-style explainers built around an ``ADRDModel``.

    ``shap_values`` handles dataset/dataloader setup and per-sample iteration;
    the actual attribution algorithm is supplied by subclasses through
    ``_shap_values_core``.

    NOTE(review): this class does not derive from ``abc.ABC``, so the
    ``@abstractmethod`` marker on ``_shap_values_core`` is not enforced at
    instantiation time. Calling ``shap_values`` on a bare ``BaseExplainer``
    therefore returns all-zero attributions.
    """

    def __init__(self, model: ADRDModel) -> None:
        """Store the model to be explained.

        Args:
            model: object exposing ``src_modalities``, ``tgt_modalities`` and
                ``net_`` (these are the attributes read by ``shap_values``).
        """
        self.model = model

    def shap_values(self,
        x,
        is_embedding: dict[str, bool] | None = None,
    ):
        """Compute per-sample attribution values for every (target, source) pair.

        Args:
            x: samples, forwarded unchanged to ``TransformerTestingDataset``.
            is_embedding: optional per-modality flags, forwarded unchanged to
                ``TransformerTestingDataset`` and to ``_shap_values_core``.

        Returns:
            A list with one entry per sample yielded by the dataloader. Each
            entry is ``{tgt_k: {src_k: float}}`` over the model's target and
            source modalities, initialised to ``0.0`` and then filled in place
            by ``_shap_values_core``. Samples whose mask has fewer than two
            ``False`` entries keep their all-zero entry.
        """
        src_keys = self.model.src_modalities
        tgt_keys = self.model.tgt_modalities

        phi = []

        # Scope the grad switch to this call: no_grad() restores the previous
        # autograd mode on exit, instead of leaving gradients globally
        # disabled for the caller.
        with torch.no_grad():
            # set nn to eval mode
            self.model.net_.eval()

            # initialize dataset and dataloader object
            dat = TransformerTestingDataset(x, src_keys, is_embedding)
            ldr = DataLoader(
                dataset=dat,
                batch_size=1,
                shuffle=False,
                drop_last=False,
                num_workers=0,
                collate_fn=TransformerTestingDataset.collate_fn,
            )

            # loop through instances and compute shap values
            for smp, mask in ldr:
                # Allocate the placeholder per sample so the result always has
                # exactly one entry for each sample, whatever the input size.
                phi_ = {
                    tgt_k: {src_k: 0.0 for src_k in src_keys}
                    for tgt_k in tgt_keys
                }
                phi.append(phi_)

                mask_flat = torch.concatenate(list(mask.values()))
                # Number of mask entries that are False (presumably the
                # features that are present -- confirm against the dataset).
                n_unmasked = torch.logical_not(mask_flat).sum().item()

                # With zero or one unmasked entry the zero placeholder is kept
                # as the result; the core routine is only run for two or more.
                if n_unmasked >= 2:
                    self._shap_values_core(smp, mask, phi_, is_embedding)

        return phi

    @abstractmethod
    def _shap_values_core(self,
        smp: dict[str, Tensor],
        mask: dict[str, Tensor],
        phi_: dict[str, dict[str, float]],
        is_embedding: dict[str, bool] | None = None,
    ):
        """To implement different algorithms.

        Args:
            smp: one collated sample from the dataloader, keyed by modality.
            mask: the mask dict collated alongside ``smp``.
            phi_: this sample's ``{tgt_k: {src_k: float}}`` placeholder;
                implementations are expected to update it in place, since
                ``shap_values`` ignores the return value.
            is_embedding: the value passed to ``shap_values``, forwarded as is.
        """
        pass