# -*- coding: utf-8 -*-
# PyTorch Implementation of Attention Modules
#
# Implementation based on: https://github.com/mahmoodlab/CLAM
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen

from typing import Tuple

import torch
import torch.nn as nn


class Attention(nn.Module):
    """Basic attention module. See https://github.com/AMLab-Amsterdam/AttentionDeepMIL

    Args:
        in_features (int, optional): Input feature dimension of the attention module. Defaults to 1024.
        attention_features (int, optional): Number of hidden attention features. Defaults to 128.
        num_classes (int, optional): Number of output classes. Defaults to 2.
        dropout (bool, optional): If True, dropout is used. Defaults to False.
        dropout_rate (float, optional): Dropout rate, only applied if dropout is True.
            Needs to be between 0.0 and 1.0. Defaults to 0.25.
    """
    def __init__(
        self,
        in_features: int = 1024,
        attention_features: int = 128,
        num_classes: int = 2,
        dropout: bool = False,
        dropout_rate: float = 0.25,
    ):
        super(Attention, self).__init__()
        # naming
        self.model_name = "AttentionModule"

        # set parameter dimensions for attention
        self.attention_features = attention_features
        self.in_features = in_features
        self.num_classes = num_classes
        self.dropout = dropout
        self.d_rate = dropout_rate

        if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
            self.attention = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Tanh(),
                nn.Dropout(self.d_rate),
                nn.Linear(self.attention_features, self.num_classes),
            )
        else:
            self.attention = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Tanh(),
                nn.Linear(self.attention_features, self.num_classes),
            )
    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, calculating attention scores for the given input bag.

        Args:
            H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                * Attention scores. Shape: (Number of instances, Number of classes)
                * H: Bag of instances. Shape: (Number of instances, Feature-dimensions)
        """
        A = self.attention(H)
        return A, H
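

# --- Illustrative usage sketch (assumption: not part of the original CLAM-based
# module, added only to show how the attention scores might be used). The bag
# size of 50, the single attention head (num_classes=1), and the default feature
# dimension of 1024 are arbitrary choices for illustration.
def _example_attention_pooling() -> torch.Tensor:
    """Softmax-normalise attention scores over a random bag and pool it."""
    bag = torch.rand(50, 1024)  # bag of 50 instances, 1024 features each
    attention_module = Attention(in_features=1024, num_classes=1)
    A, H = attention_module(bag)  # A: (50, 1), H: (50, 1024)
    weights = torch.softmax(A, dim=0)  # normalise attention over the instances
    return weights.transpose(0, 1) @ H  # pooled bag embedding, shape (1, 1024)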


class AttentionGated(nn.Module):
    """Gated attention module. See https://github.com/AMLab-Amsterdam/AttentionDeepMIL

    Args:
        in_features (int, optional): Input feature dimension of the attention module. Defaults to 1024.
        attention_features (int, optional): Number of hidden attention features. Defaults to 128.
        num_classes (int, optional): Number of output classes. Defaults to 2.
        dropout (bool, optional): If True, dropout is used. Defaults to False.
        dropout_rate (float, optional): Dropout rate, only applied if dropout is True.
            Needs to be between 0.0 and 1.0. Defaults to 0.25.
    """
    def __init__(
        self,
        in_features: int = 1024,
        attention_features: int = 128,
        num_classes: int = 2,
        dropout: bool = False,
        dropout_rate: float = 0.25,
    ):
        super(AttentionGated, self).__init__()
        # naming
        self.model_name = "AttentionModuleGated"

        # set parameter dimensions for attention
        self.attention_features = attention_features
        self.in_features = in_features
        self.num_classes = num_classes
        self.dropout = dropout
        self.d_rate = dropout_rate

        if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
            # tanh branch of the gated attention mechanism
            self.attention_V = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Tanh(),
                nn.Dropout(self.d_rate),
            )
            # sigmoid gate
            self.attention_U = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features),
                nn.Sigmoid(),
                nn.Dropout(self.d_rate),
            )
            # final projection to attention scores
            self.attention_W = nn.Sequential(
                nn.Linear(self.attention_features, self.num_classes)
            )
        else:
            # tanh branch of the gated attention mechanism
            self.attention_V = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features), nn.Tanh()
            )
            # sigmoid gate
            self.attention_U = nn.Sequential(
                nn.Linear(self.in_features, self.attention_features), nn.Sigmoid()
            )
            # final projection to attention scores
            self.attention_W = nn.Sequential(
                nn.Linear(self.attention_features, self.num_classes)
            )
    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, calculating gated attention scores for the given input bag.

        Args:
            H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:
                * Attention scores. Shape: (Number of instances, Number of classes)
                * H: Bag of instances. Shape: (Number of instances, Feature-dimensions)
        """
        v = self.attention_V(H)
        u = self.attention_U(H)
        A = self.attention_W(v * u)  # element-wise gating before the final projection
        return A, H
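

# A small self-contained demo (assumption: added for illustration only, not part
# of the original file). Running the module directly compares the plain and the
# gated attention variants on the same random bag; the printed shapes follow the
# docstrings above.
if __name__ == "__main__":
    torch.manual_seed(0)
    bag = torch.rand(32, 1024)  # 32 instances with 1024-dimensional features

    for module in (Attention(dropout=True), AttentionGated(dropout=True)):
        module.eval()  # disable dropout for a deterministic forward pass
        A, H = module(bag)
        print(f"{module.model_name}: scores {tuple(A.shape)}, instances {tuple(H.shape)}")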