File size: 3,680 Bytes
cc9dfd7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
from ..torch_core import *
from ..layers import *
from ..basic_data import *
from ..basic_train import *
from ..train import ClassificationInterpretation

__all__ = ['TabularModel']

class TabularModel(Module):
    "Basic model for tabular data."
    def __init__(self, emb_szs:ListSizes, n_cont:int, out_sz:int, layers:Collection[int], ps:Collection[float]=None,
                 emb_drop:float=0., y_range:OptRange=None, use_bn:bool=True, bn_final:bool=False):
        super().__init__()
        ps = ifnone(ps, [0]*len(layers))
        ps = listify(ps, layers)
        self.embeds = nn.ModuleList([embedding(ni, nf) for ni,nf in emb_szs])
        self.emb_drop = nn.Dropout(emb_drop)
        self.bn_cont = nn.BatchNorm1d(n_cont)
        n_emb = sum(e.embedding_dim for e in self.embeds)
        self.n_emb,self.n_cont,self.y_range = n_emb,n_cont,y_range
        sizes = self.get_sizes(layers, out_sz)
        actns = [nn.ReLU(inplace=True) for _ in range(len(sizes)-2)] + [None]
        layers = []
        for i,(n_in,n_out,dp,act) in enumerate(zip(sizes[:-1],sizes[1:],[0.]+ps,actns)):
            layers += bn_drop_lin(n_in, n_out, bn=use_bn and i!=0, p=dp, actn=act)
        if bn_final: layers.append(nn.BatchNorm1d(sizes[-1]))
        self.layers = nn.Sequential(*layers)

    def get_sizes(self, layers, out_sz):
        return [self.n_emb + self.n_cont] + layers + [out_sz]

    def forward(self, x_cat:Tensor, x_cont:Tensor) -> Tensor:
        if self.n_emb != 0:
            x = [e(x_cat[:,i]) for i,e in enumerate(self.embeds)]
            x = torch.cat(x, 1)
            x = self.emb_drop(x)
        if self.n_cont != 0:
            x_cont = self.bn_cont(x_cont)
            x = torch.cat([x, x_cont], 1) if self.n_emb != 0 else x_cont
        x = self.layers(x)
        if self.y_range is not None:
            x = (self.y_range[1]-self.y_range[0]) * torch.sigmoid(x) + self.y_range[0]
        return x

@classmethod
def _cl_int_from_learner(cls, learn:Learner, ds_type=DatasetType.Valid, activ:nn.Module=None):
    "Creates an instance of 'ClassificationInterpretation"
    preds = learn.get_preds(ds_type=ds_type, activ=activ, with_loss=True)
    return cls(learn, *preds, ds_type=ds_type)

def _cl_int_plot_top_losses(self, k, largest:bool=True, return_table:bool=False)->Optional[plt.Figure]:
    "Generates a dataframe of 'top_losses' along with their prediction, actual, loss, and probability of the actual class."
    tl_val, tl_idx = self.top_losses(k, largest)
    classes = self.data.classes
    cat_names = self.data.x.cat_names
    cont_names = self.data.x.cont_names
    df = pd.DataFrame(columns=[['Prediction', 'Actual', 'Loss', 'Probability'] + cat_names + cont_names])
    for i, idx in enumerate(tl_idx):
        da, cl = self.data.dl(self.ds_type).dataset[idx]
        cl = int(cl)
        t1 = str(da)
        t1 = t1.split(';')
        arr = []
        arr.extend([classes[self.pred_class[idx]], classes[cl], f'{self.losses[idx]:.2f}',
                    f'{self.preds[idx][cl]:.2f}'])
        for x in range(len(t1)-1):
            _, value = t1[x].rsplit(' ', 1)
            arr.append(value)
        df.loc[i] = arr
    display(df)
    return_fig = return_table
    if ifnone(return_fig, defaults.return_fig): return df


ClassificationInterpretation.from_learner = _cl_int_from_learner
ClassificationInterpretation.plot_top_losses = _cl_int_plot_top_losses

def _learner_interpret(learn:Learner, ds_type:DatasetType = DatasetType.Valid):
    "Create a 'ClassificationInterpretation' object from 'learner' on 'ds_type'."
    return ClassificationInterpretation.from_learner(learn, ds_type=ds_type)

Learner.interpret = _learner_interpret