import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy_with_logits
import math
from transformers import PreTrainedModel
from .configuration_flowformer import FlowformerConfig


class MAB(nn.Module):
    """
    Multihead Attention Block (MAB) from https://arxiv.org/abs/1810.00825.
    """

    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool = False):
        super(MAB, self).__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        # Project queries, keys and values to the common dimension dim_V
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the last dimension into num_heads chunks and stack them along the batch
        # dimension so that all heads share a single batched matrix multiply.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # Scaled dot-product attention with a residual connection
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        # Position-wise feed-forward with a residual connection
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O
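

# A minimal, hedged sketch (not part of the original model code): it only illustrates the
# tensor shapes MAB expects, i.e. queries of shape (batch, n_q, dim_Q) attending over
# keys/values of shape (batch, n_k, dim_K). The helper name `_demo_mab_shapes` and the
# sizes below are made up for illustration.
def _demo_mab_shapes():
    mab = MAB(dim_Q=16, dim_K=8, dim_V=32, num_heads=4, ln=True)
    Q = torch.randn(2, 10, 16)     # (batch, n_q, dim_Q)
    K = torch.randn(2, 50, 8)      # (batch, n_k, dim_K)
    O = mab(Q, K)
    assert O.shape == (2, 10, 32)  # output keeps the query length, projected to dim_V
    return O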


class ISAB(nn.Module):
    """
    The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.
    """

    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool = False):
        super(ISAB, self).__init__()
        # Learned inducing points: a small set of num_inds vectors the input attends through
        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        # The inducing points attend to the input set, then the input attends back to the summary
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)
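

# A minimal, hedged sketch (not part of the original model code): ISAB attends from num_inds
# learned inducing points to the input and back, so per-element outputs are produced at
# O(n * num_inds) cost rather than O(n^2) in the set size n. Names and sizes below are made up.
def _demo_isab_shapes():
    isab = ISAB(dim_in=8, dim_out=32, num_heads=4, num_inds=16, ln=True)
    X = torch.randn(2, 1000, 8)      # a batch of two sets with 1000 elements of 8 features each
    H = isab(X)
    assert H.shape == (2, 1000, 32)  # one dim_out-dimensional output per set element
    return H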


class Flowformer(PreTrainedModel):
    r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage and
    behavior.

    Parameters:
        config (`FlowformerConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    config_class = FlowformerConfig

    def __init__(self, config: FlowformerConfig):
        super().__init__(config)

        # Load config
        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1
        self._markers = config.markers

        # Define encoder: a stack of ISAB blocks mapping dim_input -> dim_hidden -> ... -> dim_input
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))  # num_heads == 1 because dim_input can be a prime number
        self.enc = nn.Sequential(*enc_layers)

        # Define decoder: a per-element linear head producing one logit per event
        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def markers(self):
        return self._markers
    def forward(self, tensor: torch.Tensor, labels: torch.Tensor = None, markers: list = None):
        r"""
        Args:
            tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
                The sample used as a basis for the prediction.
            labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Optional ground truth labels for computing the loss.
            markers (`list` of length `num_markers`, *optional*):
                The list of markers in the same order as the last dimension of the input tensor.
        """
        B, L, M = tensor.shape

        if markers is not None:
            assert len(markers) == M, "last dimension of input must be equal to number of markers"
            # Align the input channels with the marker order of the pretrained model;
            # model markers that are absent from the input stay zero.
            zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
            valid_markers = [m for m in markers if m in self.markers()]
            idx = [self.markers().index(m) for m in valid_markers]
            zeros[:, :, idx] = tensor  # select only the markers that are in the pretrained model
            tensor = zeros

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]

        if labels is not None:
            return {
                'loss': binary_cross_entropy_with_logits(output, labels),
                'logits': output,
                'prediction': torch.where(output > 0, 1, 0)
            }
        else:
            return {
                'logits': output,
                'prediction': torch.where(output > 0, 1, 0)
            }
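

# Hedged usage sketch, not part of the original file. It assumes FlowformerConfig accepts the
# fields read in __init__ above as constructor arguments and exposes a `markers` attribute;
# the marker names and sizes below are invented for illustration only.
if __name__ == "__main__":
    config = FlowformerConfig(
        dim_input=4,
        dim_hidden=32,
        num_heads=4,
        num_inds=16,
        hidden_layers=2,
        layer_norm=True,
        markers=["CD45", "CD19", "CD10", "CD34"],
    )
    model = Flowformer(config)

    events = torch.randn(2, 128, 4)                 # (batch_size, sequence_length, num_markers)
    labels = torch.randint(0, 2, (2, 128)).float()  # per-event binary ground truth
    out = model(events, labels=labels, markers=config.markers)
    print(out["loss"].item(), out["logits"].shape, out["prediction"].shape)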