# -*- coding: utf-8 -*-
# PyTorch Implementation of Attention Modules
#
# Implementation based on: https://github.com/mahmoodlab/CLAM
# @ Fabian Hörst, fabian.hoerst@uk-essen.de
# Institute for Artificial Intelligence in Medicine,
# University Medicine Essen
from typing import Tuple
import torch
import torch.nn as nn


class Attention(nn.Module):
"""Basic Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL
Args:
in_features (int, optional): Input shape of attention module. Defaults to 1024.
attention_features (int, optional): Number of attention features. Defaults to 128.
num_classes (int, optional): Number of output classes. Defaults to 2.
dropout (bool, optional): If True, dropout is used. Defaults to False.
dropout_rate (float, optional): Dropout rate, just applies if dropout parameter is true.
Needs to be between 0.0 and 1.0. Defaults to 0.25.
"""
def __init__(
self,
in_features: int = 1024,
attention_features: int = 128,
num_classes: int = 2,
dropout: bool = False,
dropout_rate: float = 0.25,
):
super(Attention, self).__init__()
# naming
self.model_name = "AttentionModule"
# set parameter dimensions for attention
self.attention_features = attention_features
self.in_features = in_features
self.num_classes = num_classes
self.dropout = dropout
self.d_rate = dropout_rate
        if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
self.attention = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Tanh(),
nn.Dropout(self.d_rate),
nn.Linear(self.attention_features, self.num_classes),
)
else:
self.attention = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Tanh(),
nn.Linear(self.attention_features, self.num_classes),
)

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, calculating attention scores for a given bag of instances.

        Args:
            H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            * Attention scores. Shape: (Number of instances, Number of classes)
            * H: Unchanged bag of instances. Shape: (Number of instances, Feature-dimensions)
        """
A = self.attention(H)
return A, H
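
# A minimal usage sketch (an assumption about the surrounding pipeline, not part
# of the original module): in CLAM / AttentionDeepMIL-style models, the raw
# scores A returned above are typically transposed, passed through a softmax
# over the instance dimension, and used to pool the bag into per-class
# embeddings, e.g.:
#
#   attention = Attention(in_features=1024, num_classes=2)
#   A, H = attention(torch.rand(100, 1024))      # A: (100, 2), H: (100, 1024)
#   A = torch.softmax(A.transpose(0, 1), dim=1)  # (2, 100) attention weights
#   M = torch.mm(A, H)                           # (2, 1024) pooled bag embedding
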
class AttentionGated(nn.Module):
"""Gated Attention module. Compare https://github.com/AMLab-Amsterdam/AttentionDeepMIL
Args:
in_features (int, optional): Input shape of attention module. Defaults to 1024.
attention_features (int, optional): Number of attention features. Defaults to 128.
num_classes (int, optional): Number of output classes. Defaults to 2.
dropout (bool, optional): If True, dropout is used. Defaults to False.
dropout_rate (float, optional): Dropout rate, just applies if dropout parameter is true.
needs to be between 0.0 and 1.0. Defaults to 0.25.
"""
def __init__(
self,
in_features: int = 1024,
attention_features: int = 128,
num_classes: int = 2,
dropout: bool = False,
dropout_rate: float = 0.25,
):
super(AttentionGated, self).__init__()
# naming
self.model_name = "AttentionModuleGated"
        # set parameter dimensions for attention
self.attention_features = attention_features
self.in_features = in_features
self.num_classes = num_classes
self.dropout = dropout
self.d_rate = dropout_rate
        if self.dropout:
            assert 0.0 <= self.d_rate < 1.0, "dropout_rate must be in [0.0, 1.0)"
self.attention_V = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Tanh(),
nn.Dropout(self.d_rate),
)
self.attention_U = nn.Sequential(
nn.Linear(self.in_features, self.attention_features),
nn.Sigmoid(),
nn.Dropout(self.d_rate),
)
self.attention_W = nn.Sequential(
nn.Linear(self.attention_features, self.num_classes)
)
else:
self.attention_V = nn.Sequential(
nn.Linear(self.in_features, self.attention_features), nn.Tanh()
)
self.attention_U = nn.Sequential(
nn.Linear(self.in_features, self.attention_features), nn.Sigmoid()
)
self.attention_W = nn.Sequential(
nn.Linear(self.attention_features, self.num_classes)
)

    def forward(self, H: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Forward pass, calculating gated attention scores for a given bag of instances.

        Args:
            H (torch.Tensor): Bag of instances. Shape: (Number of instances, Feature-dimensions)

        Returns:
            Tuple[torch.Tensor, torch.Tensor]:

            * Attention scores. Shape: (Number of instances, Number of classes)
            * H: Unchanged bag of instances. Shape: (Number of instances, Feature-dimensions)
        """
v = self.attention_V(H)
u = self.attention_U(H)
A = self.attention_W(v * u)
return A, H
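

if __name__ == "__main__":
    # Minimal smoke test (added sketch, not part of the original file): run a
    # random bag of 50 instances through both attention variants and print the
    # shapes of the returned scores.
    bag = torch.rand(50, 1024)

    plain = Attention(in_features=1024, attention_features=128, num_classes=2)
    A_plain, _ = plain(bag)
    print("Attention scores:", A_plain.shape)        # torch.Size([50, 2])

    gated = AttentionGated(
        in_features=1024, attention_features=128, num_classes=2, dropout=True
    )
    A_gated, _ = gated(bag)
    print("Gated attention scores:", A_gated.shape)  # torch.Size([50, 2])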