matth committed on
Commit
f710746
1 Parent(s): 2000e52

Upload Flowformer

Browse files
config.json ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "Flowformer"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_flowformer.FlowformerConfig",
7
+ "AutoModel": "model_flowformer.Flowformer"
8
+ },
9
+ "dim_hidden": 32,
10
+ "dim_input": 11,
11
+ "hidden_layers": 3,
12
+ "layer_norm": true,
13
+ "markers": [
14
+ "TIME",
15
+ "FSC-A",
16
+ "FSC-W",
17
+ "SSC-A",
18
+ "CD20",
19
+ "CD10",
20
+ "CD45",
21
+ "CD34",
22
+ "CD19",
23
+ "CD38",
24
+ "SY41"
25
+ ],
26
+ "num_heads": 4,
27
+ "num_inds": 16,
28
+ "torch_dtype": "float32",
29
+ "transformers_version": "4.28.1"
30
+ }
configuration_flowformer.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+
3
class FlowformerConfig(PretrainedConfig):
    """Hyper-parameters for the Flowformer model.

    Args:
        dim_hidden: Feature width of the hidden attention layers. Must be
            divisible by ``num_heads`` (``dim_hidden % num_heads == 0``).
        num_heads: Number of attention heads in the hidden layers.
        num_inds: Number of inducing points per induced set attention block.
        hidden_layers: Number of hidden attention layers in the encoder.
        layer_norm: Whether the attention blocks apply layer normalisation.
        dim_input: Number of input features per element; must equal
            ``len(markers)``.
        markers: Marker names, one per input feature, in input-column order.
            ``None`` selects the default 11-marker panel.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        AssertionError: If ``dim_input`` differs from ``len(markers)``.
    """

    def __init__(self,
                 dim_hidden: int = 32,  # must be divisible by num_heads
                 num_heads: int = 4,
                 num_inds: int = 16,
                 hidden_layers: int = 3,
                 layer_norm: bool = True,
                 dim_input: int = 11,
                 markers: list = None,
                 **kwargs
                 ):
        # Build the default panel per call: a list literal in the signature
        # would be one shared object that every default-constructed config
        # stores (and could mutate) through ``self.markers``.
        if markers is None:
            markers = ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10",
                       "CD45", "CD34", "CD19", "CD38", "SY41"]

        # NOTE: kept as an assert so the raised exception type is unchanged;
        # be aware that asserts are skipped when Python runs with -O.
        assert dim_input == len(markers), "dim_input must be equal to the number of markers"

        self.dim_hidden = dim_hidden
        self.num_heads = num_heads
        self.num_inds = num_inds
        self.hidden_layers = hidden_layers
        self.layer_norm = layer_norm
        self.dim_input = dim_input
        self.markers = markers
        super().__init__(**kwargs)
model_flowformer.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn.functional import binary_cross_entropy_with_logits
5
+ import math
6
+ from transformers import PreTrainedModel
7
+ from .configuration_flowformer import FlowformerConfig
8
+
9
+
10
class MAB(nn.Module):
    """Multihead Attention Block (MAB) from https://arxiv.org/abs/1810.00825.

    Queries attend over a set of keys/values, followed by a residual
    feed-forward step. Layer normalisation after each step is optional.

    Args:
        dim_Q: Feature width of the query input.
        dim_K: Feature width of the key/value input.
        dim_V: Feature width of the projected queries/keys/values and of the
            output.
        num_heads: Number of attention heads; the projected features are cut
            into chunks of ``dim_V // num_heads`` columns.
        ln: If true, add a ``LayerNorm`` after the attention step and after
            the feed-forward step.
    """

    def __init__(self, dim_Q, dim_K, dim_V, num_heads, ln=False):
        super().__init__()

        self.dim_V = dim_V
        self.num_heads = num_heads
        self.fc_q = nn.Linear(dim_Q, dim_V)
        self.fc_k = nn.Linear(dim_K, dim_V)
        self.fc_v = nn.Linear(dim_K, dim_V)

        # ln0/ln1 only exist when requested; forward() probes for them.
        if ln:
            self.ln0 = nn.LayerNorm(dim_V)
            self.ln1 = nn.LayerNorm(dim_V)
        self.fc_o = nn.Linear(dim_V, dim_V)

    def forward(self, Q, K):
        """Attend from ``Q`` over ``K``.

        Args:
            Q: Query tensor of shape ``(batch, n_q, dim_Q)``.
            K: Key/value tensor of shape ``(batch, n_k, dim_K)``.

        Returns:
            Tensor of shape ``(batch, n_q, dim_V)``.
        """
        Q = self.fc_q(Q)
        K, V = self.fc_k(K), self.fc_v(K)

        # Split the feature axis into per-head chunks and stack the chunks
        # along the batch axis so one bmm handles all heads.
        dim_split = self.dim_V // self.num_heads
        Q_ = torch.cat(Q.split(dim_split, 2), dim=0)
        K_ = torch.cat(K.split(dim_split, 2), dim=0)
        V_ = torch.cat(V.split(dim_split, 2), dim=0)

        # NOTE(review): logits are scaled by sqrt(dim_V), not by the per-head
        # width. Left unchanged: altering it would change the outputs of
        # already-trained weights.
        A = torch.softmax(Q_.bmm(K_.transpose(1, 2)) / math.sqrt(self.dim_V), 2)

        # Residual on the projected queries, then undo the head stacking by
        # splitting the batch axis back and concatenating along features.
        O = torch.cat((Q_ + A.bmm(V_)).split(Q.size(0), 0), 2)
        O = O if getattr(self, 'ln0', None) is None else self.ln0(O)
        O = O + F.relu(self.fc_o(O))
        O = O if getattr(self, 'ln1', None) is None else self.ln1(O)

        return O
44
+
45
+
46
class ISAB(nn.Module):
    """Induced Set Attention Block (ISAB) from https://arxiv.org/abs/1810.00825.

    A learned set of ``num_inds`` inducing points first attends over the
    input set; the input set then attends over that summary.

    Args:
        dim_in: Feature width of the input set.
        dim_out: Feature width of the inducing points and of the output.
        num_heads: Number of attention heads used by both inner MABs.
        num_inds: Number of learned inducing points.
        ln: Passed to both inner MABs to enable layer normalisation.
    """

    def __init__(self, dim_in, dim_out, num_heads, num_inds, ln=False):
        super().__init__()

        # Allocate uninitialised storage and fill it with Xavier-uniform
        # values right away; torch.empty replaces the legacy torch.Tensor(...)
        # size constructor.
        self.I = nn.Parameter(torch.empty(1, num_inds, dim_out))
        nn.init.xavier_uniform_(self.I)
        self.mab0 = MAB(dim_out, dim_in, dim_out, num_heads, ln=ln)
        self.mab1 = MAB(dim_in, dim_out, dim_out, num_heads, ln=ln)

    def forward(self, X):
        """Apply the block to ``X`` of shape ``(batch, n, dim_in)``.

        Returns:
            Tensor of shape ``(batch, n, dim_out)``.
        """
        # Give every batch element its own copy of the inducing points and
        # let them attend over the input set.
        H = self.mab0(self.I.repeat(X.size(0), 1, 1), X)

        # The input set attends over the induced summary.
        return self.mab1(X, H)
62
+
63
class Flowformer(PreTrainedModel):
    """Set-attention encoder with a linear head giving one logit per element.

    The encoder is a stack of ISAB layers; the decoder is a single linear
    layer mapping each element's ``dim_input`` features to one logit.

    Args:
        config: Configuration object providing ``dim_input``, ``dim_hidden``,
            ``num_heads``, ``num_inds``, ``hidden_layers``, ``layer_norm`` and
            ``markers``.
    """

    def __init__(self, config):
        super().__init__(config)

        # Load config
        dim_input = config.dim_input
        dim_hidden = config.dim_hidden
        num_heads = config.num_heads
        num_inds = config.num_inds
        hidden_layers = config.hidden_layers
        layer_norm = config.layer_norm
        dim_output = 1
        # Fall back to the default panel when the config carries no markers.
        self._pretrained_markers = config.markers or ["TIME", "FSC-A", "FSC-W", "SSC-A", "CD20", "CD10", "CD45", "CD34", "CD19", "CD38", "SY41"]

        # Define encoder: input -> hidden, (hidden_layers - 1) x hidden ->
        # hidden, then hidden -> input width again.
        enc_layers = [ISAB(dim_input, dim_hidden, num_heads, num_inds, ln=layer_norm)]
        for _ in range(1, hidden_layers):
            enc_layers.append(ISAB(dim_hidden, dim_hidden, num_heads, num_inds, ln=layer_norm))
        # num_heads == 1 because dim_input can be a prime number
        enc_layers.append(ISAB(dim_hidden, dim_input, 1, num_inds, ln=layer_norm))
        self.enc = nn.Sequential(*enc_layers)

        # Define decoder
        dec_layers = [nn.Linear(dim_input, dim_output)]
        self.dec = nn.Sequential(*dec_layers)

    def pretrained_markers(self):
        """Return the list of marker names the model was configured with."""
        return self._pretrained_markers

    def forward(self, tensor, labels=None, markers: list = None):
        """Compute per-element logits, and the loss when labels are given.

        Args:
            tensor: Input of shape ``(B, L, M)``.
            labels: Optional targets passed with the logits to
                ``binary_cross_entropy_with_logits``.
            markers: Optional list of ``M`` marker names describing the
                columns of ``tensor``. When given, columns are moved to the
                position of the same-named pretrained marker; pretrained
                markers not supplied stay zero, and supplied markers the
                model does not know are dropped.

        Returns:
            Dict with ``'logits'`` of shape ``(B, L)`` and, when ``labels`` is
            given, ``'loss'``.

        Raises:
            AssertionError: If ``len(markers)`` differs from ``M``.
        """
        B, L, M = tensor.shape
        if markers is not None:
            assert len(markers) == M, "Number of markers in x and markers must be identical"

            zeros = torch.zeros((B, L, len(self._pretrained_markers)), device=tensor.device)

            # Name -> first position in the pretrained panel (matches
            # list.index semantics); built once instead of per marker.
            target_of = {}
            for pos, name in enumerate(self._pretrained_markers):
                target_of.setdefault(name, pos)

            # Pair each known input column with its destination column.
            # Unknown markers must be removed from the source side as well,
            # otherwise the assignment below has mismatched widths.
            src_cols = [col for col, name in enumerate(markers) if name in target_of]
            dst_cols = [target_of[markers[col]] for col in src_cols]
            src_idx = torch.tensor(src_cols, dtype=torch.long, device=tensor.device)
            dst_idx = torch.tensor(dst_cols, dtype=torch.long, device=tensor.device)

            # select only the markers that are in the pretrained model
            zeros[:, :, dst_idx] = tensor[:, :, src_idx]
            tensor = zeros

        enc_out = self.enc(tensor)
        output = self.dec(enc_out)[:, :, 0]  # drop the size-1 output axis

        if labels is not None:
            return {
                'loss': binary_cross_entropy_with_logits(output, labels),
                'logits': output
            }
        else:
            return {
                'logits': output
            }
114
+
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:055b27977924a2b82a5842c34673de48fa8478eb110374b6066508469b2c9c35
3
+ size 139813