flowformer / model_flowformer.py
matth's picture
Add reference to config class
cf97376 verified
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy_with_logits
import math
from transformers import PreTrainedModel
from .configuration_flowformer import FlowformerConfig
class MAB(nn.Module):
    """
    Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.

    Projects queries from ``dim_Q`` and keys/values from ``dim_K`` to ``dim_V``,
    runs ``num_heads``-way attention with a residual on the projected queries,
    then applies a residual ReLU feed-forward layer. With ``ln=True`` a
    LayerNorm follows each of the two residual steps.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide ``dim_V``
            (the per-head split in ``forward`` requires equal-width heads).
    """
    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool = False):
        super(MAB, self).__init__()
        # Fail early: uneven or zero-width heads would otherwise crash later
        # inside torch.cat / Tensor.split with an unhelpful message.
        if num_heads <= 0 or dim_V % num_heads != 0:
            raise ValueError(
                f"dim_V ({dim_V}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from the query set ``Q`` to the key/value set ``K``.

        Args:
            Q: 3-D tensor ``(batch, n_q, dim_Q)`` (3-D is required by ``bmm`` below).
            K: 3-D tensor ``(batch, n_k, dim_K)``; it is projected into both keys and values.

        Returns:
            Tensor of shape ``(batch, n_q, dim_V)``.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)
        dim_split = self.dim_V // self.num_heads
        # Fold the heads into the batch axis: (num_heads * batch, n, dim_split).
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)
        # NOTE: scaled by sqrt(dim_V) (the full width), not sqrt(dim_split);
        # kept unchanged so existing trained weights produce the same outputs.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        # Residual on the projected queries, then unfold the heads back onto the feature axis.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        # Row-wise feed-forward with a residual connection.
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O
class ISAB(nn.Module):
    """
    Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.

    A learned set of ``num_inds`` inducing points first attends to the input
    set (``mab0``), producing a summary of width ``dim_out``; the input set
    then attends to that summary (``mab1``). Output width is ``dim_out``.
    """
    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool = False):
        super().__init__()
        # Learned inducing points of shape (1, num_inds, dim_out), shared across
        # the batch. Attribute names (I, mab0, mab1) are state_dict keys.
        inducing_points = torch.empty(1, num_inds, dim_out)
        self.I = nn.Parameter(inducing_points)
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        # One copy of the inducing points per batch element.
        queries = self.I.repeat(X.size(0), 1, 1)
        summary = self.mab0(queries, X)
        return self.mab1(X, summary)
class Flowformer(PreTrainedModel):
    r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.
    Parameters:
        config (`FlowformerConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """
    config_class = FlowformerConfig

    def __init__(self, config: FlowformerConfig):
        super().__init__(config)
        # Load config
        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1
        self._markers = config.markers
        # Encoder: stack of ISABs mapping dim_input -> dim_hidden -> ... -> dim_input.
        # Submodule names/indices are state_dict keys; keep the layout unchanged.
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))  # num_heads == 1 because dim_input can be a prime number
        self.enc = nn.Sequential(*enc_layers)
        # Decoder: a linear layer producing one logit per sequence element.
        self.dec = nn.Sequential(nn.Linear(dim_input, dim_output))

    def markers(self):
        """Return the marker list taken from ``config.markers`` at construction."""
        return self._markers

    def _align_markers(self, tensor: torch.Tensor, markers: list) -> torch.Tensor:
        """Re-map the last axis of ``tensor`` from ``markers`` order to ``self.markers()`` order.

        Input columns whose marker is unknown to the model are dropped; model
        markers absent from ``markers`` are filled with zeros.
        """
        B, L, M = tensor.shape
        if len(markers) != M:
            raise ValueError("last dimension of input must be equal to number of markers")
        model_markers = self.markers()
        # First position of each model marker (same result as list.index), built once.
        model_pos = {}
        for pos, name in enumerate(model_markers):
            model_pos.setdefault(name, pos)
        src_idx = [i for i, name in enumerate(markers) if name in model_pos]
        dst_idx = [model_pos[markers[i]] for i in src_idx]
        # Buffer uses the model's parameter dtype so half-precision models work;
        # for the default float32 model this matches the previous float32 buffer.
        aligned = torch.zeros((B, L, len(model_markers)), dtype=self.dtype, device=tensor.device)
        if src_idx:
            # Select only the input columns the pretrained model knows about.
            aligned[:, :, dst_idx] = tensor[:, :, src_idx]
        return aligned

    def forward(self, tensor: torch.Tensor, labels: torch.Tensor = None, markers: list = None):
        r"""
        Args:
            tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
                The sample used as a basis for the prediction.
            labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Optional ground truth labels for computing the loss.
            markers (`list` of length `num_markers`, *optional*):
                The list of markers in the same order as the last dimension of the input tensor.
                When given, the input is re-mapped onto the model's marker order: markers unknown
                to the model are dropped and model markers missing from the input are zero-filled.
                When omitted, `tensor` is passed to the encoder unchanged.
        Returns:
            `dict` with `logits` (shape `(batch_size, sequence_length)`), `prediction`
            (1 where the logit is > 0, else 0) and, if `labels` is given, `loss`
            (binary cross-entropy with logits).
        Raises:
            ValueError: if `tensor` is not 3-D or `len(markers)` differs from its last dimension.
        """
        if tensor.dim() != 3:
            raise ValueError(
                f"expected tensor of shape (batch_size, sequence_length, num_markers), got {tuple(tensor.shape)}"
            )
        if markers is not None:
            tensor = self._align_markers(tensor, markers)
        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]
        result = {}
        if labels is not None:
            # BCE-with-logits needs floating targets matching the logits' dtype.
            result['loss'] = binary_cross_entropy_with_logits(output, labels.to(output.dtype))
        result['logits'] = output
        result['prediction'] = torch.where(output > 0, 1, 0)
        return result