import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.functional import binary_cross_entropy_with_logits
import math
from transformers import PreTrainedModel
from .configuration_flowformer import FlowformerConfig


class MAB(nn.Module):
    """
    Multihead Attention Block (MAB) from https://arxiv.org/abs/1810.00825.
    """
    def __init__(self, dim_Q: int, dim_K: int, dim_V: int, num_heads: int, ln: bool=False):
        super(MAB, self).__init__()

        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)

        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        # Project queries from Q and keys/values from K into the shared dim_V space.
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature dimension into num_heads chunks and stack them along the
        # batch dimension so all heads can be computed with a single bmm.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # Scaled dot-product attention with a residual connection, then merge the heads back.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        # Position-wise feed-forward with a residual connection.
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)

        return O


class ISAB(nn.Module):
    """
    The Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.
    """
    def __init__(self, dim_in: int, dim_out: int, num_heads: int, num_inds: int, ln: bool=False):
        super(ISAB, self).__init__()

        self.I = nn.Parameter(torch.Tensor(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        # Attend from the learned inducing points to the input set, then project back onto X.
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)

        return self.mab1(X, H)
    

class Flowformer(PreTrainedModel):
    r"""
    This model is a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass. Use it
    as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage and
    behavior.
    Parameters:
        config (`FlowformerConfig`): Model configuration class with all the parameters of the model.
            Initializing with a config file does not load the weights associated with the model, only the
            configuration. Check out the [`~PreTrainedModel.from_pretrained`] method to load the model weights.
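
    Example (a minimal usage sketch; the checkpoint path below is a placeholder, not an existing repository):

    ```python
    >>> model = Flowformer.from_pretrained("path/to/flowformer-checkpoint")
    >>> model.markers()  # marker names expected by this checkpoint, in input-column order
    ```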
    """
    config_class = FlowformerConfig
    
    def __init__(self, config: FlowformerConfig):
        super().__init__(config)

        # Load config
        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1  # one logit per cell (binary classification)
        self._markers = config.markers

        # Define encoder
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))  # single head, since dim_input may not be divisible by num_heads (e.g. a prime number)
        self.enc = nn.Sequential(*enc_layers)
 
        # Define decoder
        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def markers(self):
        """Returns the marker names expected by the model, in input-column order."""
        return self._markers

    def forward(self, tensor: torch.Tensor, labels: torch.Tensor=None, markers: list=None):
        r"""
            Args:
                tensor (`torch.FloatTensor` of shape `(batch_size, sequence_length, num_markers)`):
                    The sample used as a basis for the prediction.
                labels (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
                    Optional ground-truth labels for computing the loss.
                markers (`list` of length `num_markers`, *optional*):
                    The list of marker names in the same order as the last dimension of the input tensor. If given,
                    the input columns are re-aligned to the marker order expected by the pretrained model.
        """
        B, L, M = tensor.shape
        if markers is not None:
            assert len(markers) == M, "last dimension of input must be equal to the number of markers"

            # Re-align the input columns to the marker order of the pretrained model:
            # markers unknown to the model are dropped, missing ones stay zero.
            zeros = torch.zeros((B, L, len(self.markers())), device=tensor.device)
            valid_markers = [m for m in markers if m in self.markers()]
            src_idx = [markers.index(m) for m in valid_markers]
            dst_idx = [self.markers().index(m) for m in valid_markers]
            zeros[:, :, dst_idx] = tensor[:, :, src_idx]
            tensor = zeros

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]  # per-cell logits of shape (batch_size, sequence_length)

        if labels is not None:
            return {
                'loss': binary_cross_entropy_with_logits(output, labels),
                'logits': output,
                'prediction': torch.where(output > 0, 1, 0)
            }
        else:
            return {
                'logits': output,
                'prediction': torch.where(output > 0, 1, 0)
            }
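

if __name__ == "__main__":
    # Minimal smoke test (illustrative sketch only). It assumes FlowformerConfig accepts the
    # fields read in Flowformer.__init__ as keyword arguments, which is the usual
    # PretrainedConfig pattern but is not verified against configuration_flowformer.py.
    # Because of the relative import at the top, run this as a module,
    # e.g. `python -m <package>.modeling_flowformer`.
    markers = ["CD45", "CD19", "CD10", "CD34"]  # hypothetical marker panel
    config = FlowformerConfig(
        dim_input=len(markers),
        dim_hidden=32,
        num_heads=4,
        num_inds=16,
        hidden_layers=2,
        layer_norm=True,
        markers=markers,
    )
    model = Flowformer(config)

    # One "sample" of 1000 cells, each described by the four markers above.
    x = torch.randn(1, 1000, len(markers))
    labels = torch.randint(0, 2, (1, 1000)).float()

    out = model(x, labels=labels, markers=markers)
    print(out["loss"].item(), out["logits"].shape, out["prediction"].shape)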