import torch
import torch.nn as nn
import torch.nn.functional as F


def exists(val):
    return val is not None


def initialize_weights(module):
    # Xavier-initialize every nn.Linear and zero its bias
    for m in module.modules():
        if isinstance(m, nn.Linear):
            nn.init.xavier_normal_(m.weight)
            if m.bias is not None:
                m.bias.data.zero_()


"""
Attention Network with Sigmoid Gating (3 fc layers)
args:
    L: input feature dimension
    D: hidden layer dimension
    n_tasks: number of attention scores produced per instance
Dropout (p = 0.25) is always applied after the tanh and sigmoid branches.
"""
class Attn_Net_Gated(nn.Module):

    def __init__(self, L = 1024, D = 256, n_tasks = 1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = nn.Sequential(nn.Linear(L, D), nn.Tanh(), nn.Dropout(0.25))
        self.attention_b = nn.Sequential(nn.Linear(L, D), nn.Sigmoid(), nn.Dropout(0.25))
        self.attention_c = nn.Linear(D, n_tasks)

    def forward(self, x):
        a = self.attention_a(x)     # tanh branch
        b = self.attention_b(x)     # sigmoid gate
        A = a.mul(b)                # element-wise gated attention, shape (N, D)
        A = self.attention_c(A)     # unnormalized scores, shape (N, n_tasks)
        return A, x
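
# Example usage (illustrative only; the (5, 1024) input shape is an assumption):
#     attn = Attn_Net_Gated(L=1024, D=256, n_tasks=1)
#     scores, feats = attn(torch.randn(5, 1024))
#     # scores: (5, 1) unnormalized attention logits, feats: (5, 1024) passthrough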


"""
Code borrowed from: https://github.com/mahmoodlab/TOAD

args:
    input_dim: input feature dimension
    size_arg: size config of the attention network ("small" or "big")
    n_classes: number of classes
"""

class DeepAttnMIL(nn.Module):

    def __init__(self, input_dim = 1024, size_arg = "big", n_classes = 2):
        super(DeepAttnMIL, self).__init__()
        self.size_dict = {"small": [input_dim, 512, 256], "big": [input_dim, 512, 384]}
        size = self.size_dict[size_arg]

        # attention backbone: fc -> ReLU -> dropout, followed by gated attention
        self.attention_net = nn.Sequential(
            nn.Linear(size[0], size[1]),
            nn.ReLU(),
            nn.Dropout(0.25),
            Attn_Net_Gated(L = size[1], D = size[2], n_tasks = 1))

        self.classifier = nn.Linear(size[1], n_classes)

        initialize_weights(self)

    def forward(self, h, return_features=False, attention_only=False):
        # h: (N, input_dim) bag of instance features
        A, h = self.attention_net(h)      # A: (N, 1) scores, h: (N, size[1]) embeddings
        A = torch.transpose(A, 1, 0)      # A: (1, N)
        if attention_only:
            return A[0]

        A = F.softmax(A, dim=1)           # normalize attention over the N instances
        M = torch.mm(A, h)                # (1, size[1]) attention-pooled bag embedding

        if return_features:
            return M

        logits = self.classifier(M)       # (1, n_classes)

        return logits
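

# The smoke test below is a minimal sketch added for illustration; the bag size
# (8 instances) and the 1024-d feature dimension are assumptions, not values
# taken from the original repository.
if __name__ == "__main__":
    model = DeepAttnMIL(input_dim=1024, size_arg="big", n_classes=2)
    model.eval()
    bag = torch.randn(8, 1024)                       # 8 instances, 1024-d features

    with torch.no_grad():
        attn = model(bag, attention_only=True)       # raw attention scores, shape (8,)
        feats = model(bag, return_features=True)     # pooled bag embedding, shape (1, 512)
        logits = model(bag)                          # class logits, shape (1, 2)

    print(attn.shape, feats.shape, logits.shape)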
|