import math

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy_with_logits
from transformers import PreTrainedModel

from .configuration_flowformer import FlowformerConfig


class MAB(nn.Module):
    """
    Multihead attention Block (MAB) from https://arxiv.org/abs/1810.00825.

    Args:
        dim_Q: Feature size of the query input.
        dim_K: Feature size of the key/value input.
        dim_V: Feature size of the projected queries/keys/values and of the output.
        num_heads: Number of attention heads. ``dim_V`` is split into chunks of
            ``dim_V // num_heads`` features, one chunk per head; presumably
            ``dim_V`` should be divisible by ``num_heads`` (uneven chunks cannot
            be stacked in ``forward``).
        ln: If True, apply a LayerNorm after the attention residual and another
            after the feed-forward residual.
    """

    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool = False):
        super().__init__()
        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)
        if ln:
            # Only registered when requested; forward() probes for them via getattr.
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from ``Q`` to ``K``.

        Args:
            Q: Query tensor of shape ``(B, n_q, dim_Q)``.
            K: Key/value source tensor of shape ``(B, n_k, dim_K)``.

        Returns:
            Tensor of shape ``(B, n_q, dim_V)``.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature dimension into per-head chunks and stack the heads
        # along the batch dimension so one batched matmul covers every head.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # NOTE(review): scores are scaled by sqrt(dim_V) (full width), not by
        # sqrt(dim_split) (per-head width). Left unchanged on purpose: altering it
        # would change the outputs of already-trained weights.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        # Residual on the projected queries, then un-stack the heads back into
        # the feature dimension.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)
        return O


class ISAB(nn.Module):
    """
    The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.

    A set of ``num_inds`` learned inducing points first attends to the input
    (``mab0``), and the input then attends to that summary (``mab1``).

    Args:
        dim_in: Feature size of the input.
        dim_out: Feature size of the inducing points and of the output.
        num_heads: Number of attention heads in both inner MABs.
        num_inds: Number of learned inducing points.
        ln: Whether the inner MABs use LayerNorm.
    """

    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool = False):
        super().__init__()
        self.I = nn.Parameter(torch.empty(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        """Map ``X`` of shape ``(B, n, dim_in)`` to shape ``(B, n, dim_out)``."""
        # Broadcast the inducing points over the batch: (B, num_inds, dim_out).
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)
        return self.mab1(X, H)


class Flowformer(PreTrainedModel):
    r"""
    Flowformer: a stack of ISAB encoder layers followed by a linear decoder that
    emits one logit per element of the input sequence.

    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
    and behavior.

    Parameters:
        config (`FlowformerConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
    """

    config_class = FlowformerConfig

    def __init__(self, config: FlowformerConfig):
        super().__init__(config)

        # Load config
        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1

        self._markers = config.markers

        # Define encoder: an input ISAB, (hidden_layers - 1) hidden ISABs, and an
        # output ISAB that projects back to dim_input. At least one input ISAB is
        # always built, even when hidden_layers < 1.
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        # num_heads == 1 because dim_input can be a prime number
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))
        self.enc = nn.Sequential(*enc_layers)

        # Define decoder
        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def markers(self):
        """Return the marker names the model was configured with, in input-channel order."""
        return self._markers

    def _align_markers(self, tensor: torch.Tensor, markers) -> torch.Tensor:
        """Re-order the input channels into the model's marker layout.

        Input channels whose marker is not among ``self.markers()`` are dropped.
        Model markers that are missing from the input are zero-filled.
        """
        model_markers = self.markers()
        # First occurrence wins, mirroring list.index semantics.
        position = {}
        for i, m in enumerate(model_markers):
            position.setdefault(m, i)

        src_idx = [i for i, m in enumerate(markers) if m in position]
        dst_idx = [position[markers[i]] for i in src_idx]

        B, L, _ = tensor.shape
        aligned = torch.zeros((B, L, len(model_markers)), device=tensor.device)
        # select only the markers that are in the pretrained model
        aligned[:, :, dst_idx] = tensor[:, :, src_idx]
        return aligned

    def forward(self, tensor: torch.Tensor, labels: torch.Tensor = None, markers: list = None):
        r"""
        Args:
            tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
                The sample used as a basis for the prediction.
            labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Optional ground truth labels for computing the loss. Non-float labels are cast to
                the logits' dtype.
            markers (`list` of length `num_markers`, *optional*):
                The list of markers in the same order as the last dimension of the input tensor.
                Input channels whose marker is unknown to the model are dropped. Model markers
                absent from the input are filled with zeros.

        Returns:
            `dict` with keys `'loss'` (only when `labels` is given; binary cross-entropy with
            logits), `'logits'` of shape `(batch_size, sequence_length)` and `'prediction'`
            (1 where the logit is > 0, else 0).

        Raises:
            ValueError: if `markers` is given and its length differs from the last dimension
                of `tensor`.
        """
        B, L, M = tensor.shape
        if markers is not None:
            if len(markers) != M:
                raise ValueError("last dimension of input must be equal to number of markers")
            tensor = self._align_markers(tensor, markers)

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]

        result = {}
        if labels is not None:
            result['loss'] = binary_cross_entropy_with_logits(output, labels.to(output.dtype))
        result['logits'] = output
        result['prediction'] = torch.where(output > 0, 1, 0)
        return result