import torch
import torch.nn as nn
import torch.nn.functional as F


def exists(val):
    return val is not None


def initialize_weights(module):
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()


class Attn_Net_Gated(nn.Module):
    """
    Attention network with sigmoid gating (3 fc layers).

    Args:
        L: input feature dimension
        D: hidden attention dimension
        n_tasks: number of attention branches (attention scores per instance)
    """
    def __init__(self, L=1024, D=256, n_tasks=1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = nn.Sequential(nn.Linear(L, D), nn.Tanh(), nn.Dropout(0.25))
        self.attention_b = nn.Sequential(nn.Linear(L, D), nn.Sigmoid(), nn.Dropout(0.25))
        self.attention_c = nn.Linear(D, n_tasks)

    def forward(self, x):
        a = self.attention_a(x)   # tanh branch
        b = self.attention_b(x)   # sigmoid gating branch
        A = a.mul(b)              # gated attention features
        A = self.attention_c(A)   # N x n_tasks attention scores
        return A, x


class DeepAttnMIL(nn.Module):
    """
    Attention-based deep multiple-instance learning classifier.
    Code borrowed from: https://github.com/mahmoodlab/TOAD

    Args:
        input_dim: dimension of the input instance features
        size_arg: size configuration of the attention network ("small" or "big")
        n_classes: number of output classes
    """
    def __init__(self, input_dim=1024, size_arg="big", n_classes=2):
        super(DeepAttnMIL, self).__init__()
        self.size_dict = {"small": [input_dim, 512, 256], "big": [input_dim, 512, 384]}
        size = self.size_dict[size_arg]
        self.attention_net = nn.Sequential(
            nn.Linear(size[0], size[1]), nn.ReLU(), nn.Dropout(0.25),
            Attn_Net_Gated(L=size[1], D=size[2], n_tasks=1))
        self.classifier = nn.Linear(size[1], n_classes)
        initialize_weights(self)

    def forward(self, h, return_features=False, attention_only=False):
        # h: N x input_dim bag of instance features
        A, h = self.attention_net(h)   # A: N x 1 attention scores, h: N x 512 embeddings
        A = torch.transpose(A, 1, 0)   # 1 x N
        if attention_only:
            return A[0]                # raw (pre-softmax) attention scores
        A = F.softmax(A, dim=1)        # normalize attention over instances
        M = torch.mm(A, h)             # 1 x 512 attention-pooled bag embedding
        if return_features:
            return M
        logits = self.classifier(M)    # 1 x n_classes
        return logits
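

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the original file). It assumes the input
# is a single bag of N instance features of dimension `input_dim` (1024 here);
# the bag size and values below are illustrative only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model = DeepAttnMIL(input_dim=1024, size_arg="big", n_classes=2)
    bag = torch.randn(100, 1024)                # N = 100 instances, 1024-d features

    logits = model(bag)                         # (1, n_classes) bag-level logits
    attn = model(bag, attention_only=True)      # (N,) unnormalized attention scores
    feats = model(bag, return_features=True)    # (1, 512) pooled bag embedding

    print(logits.shape, attn.shape, feats.shape)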